| |
| """ |
| GeoLIP Core β Back to Basics |
| ============================== |
| Conv encoder β sphere β ConstellationCore β classifier. |
| |
| Two augmented views β InfoNCE + CE + attract + CV + spread. |
| Anchor push every N batches (self-distillation across time). |
| |
| Uses constellation.py for all geometric components. |
| """ |
|
|
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from tqdm import tqdm

from constellation import ConstellationCore, GeometricOps
|
|
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
|
|
|
|
| |
| |
| |
|
|
class ConvEncoder(nn.Module):
    """6-layer conv backbone -> flat vector -> project -> LayerNorm.

    Three stages of paired 3x3 convs (64 -> 128 -> 256 channels), each
    with BatchNorm + GELU and a 2x2 max pool. AdaptiveAvgPool2d(1)
    collapses whatever spatial grid remains, so any input resolution
    that survives three poolings (>= 8x8) is accepted.

    Args:
        output_dim: size of the projected embedding (default 192).
        in_channels: number of input image channels (default 3 = RGB).
            Generalized from the previously hard-coded RGB input;
            default preserves existing behavior.
    """
    def __init__(self, output_dim=192, in_channels=3):
        super().__init__()
        self.features = nn.Sequential(
            # Stage 1: in_channels -> 64, spatial /2
            nn.Conv2d(in_channels, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(),
            nn.MaxPool2d(2),

            # Stage 2: 64 -> 128, spatial /2
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),
            nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),
            nn.MaxPool2d(2),

            # Stage 3: 128 -> 256, spatial /2
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.GELU(),
            nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.GELU(),
            nn.MaxPool2d(2),

            # Collapse the spatial grid to one 256-d vector per image.
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        # Linear projection to the target dimension, scale-normalized.
        self.proj = nn.Sequential(
            nn.Linear(256, output_dim),
            nn.LayerNorm(output_dim),
        )

    def forward(self, x):
        """Return (B, output_dim) embeddings for images x of shape (B, C, H, W)."""
        return self.proj(self.features(x))
|
|
|
|
| |
| |
| |
|
|
class GeoLIPCore(nn.Module):
    """Conv encoder -> L2 normalize -> ConstellationCore.

    The encoder is the only component that ever sees pixels; everything
    downstream of the normalization step operates purely on unit-sphere
    geometry.
    """

    def __init__(self, num_classes=10, output_dim=192,
                 n_anchors=64, n_comp=8, d_comp=64,
                 anchor_drop=0.15, activation='squared_relu',
                 cv_target=0.22, infonce_temp=0.07):
        super().__init__()
        self.output_dim = output_dim

        # Snapshot of the constructor arguments; saved into checkpoints
        # so the model can be rebuilt from its config alone.
        self.config = {
            'num_classes': num_classes,
            'output_dim': output_dim,
            'n_anchors': n_anchors,
            'n_comp': n_comp,
            'd_comp': d_comp,
            'anchor_drop': anchor_drop,
            'activation': activation,
            'cv_target': cv_target,
            'infonce_temp': infonce_temp,
        }

        self.encoder = ConvEncoder(output_dim)
        self.core = ConstellationCore(
            num_classes=num_classes, dim=output_dim,
            n_anchors=n_anchors, n_comp=n_comp, d_comp=d_comp,
            anchor_drop=anchor_drop, activation=activation,
            cv_target=cv_target, infonce_temp=infonce_temp,
        )
        self._init_encoder_weights()

    def _init_encoder_weights(self):
        """Re-initialize encoder weights with standard per-layer schemes."""
        for module in self.encoder.modules():
            if isinstance(module, nn.Conv2d):
                # Kaiming for convs (fan_out suits post-conv BatchNorm).
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                # Truncated normal for the projection head.
                nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Encode pixels, project onto the unit sphere, run the constellation."""
        unit_emb = F.normalize(self.encoder(x), dim=-1)
        return self.core(unit_emb)

    def compute_loss(self, output, targets, output_aug=None):
        # Loss computation is delegated entirely to the geometric core.
        return self.core.compute_loss(output, targets, output_aug)

    def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
        # Anchor self-distillation step; delegated to the core.
        return self.core.push_anchors_to_centroids(emb_buffer, label_buffer, lr)
|
|
|
|
| |
| |
| |
|
|
# Per-channel mean/std of the CIFAR-10 training set, used to normalize
# both the augmented training views and the validation images.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)
|
|
|
|
class TwoViewDataset(torch.utils.data.Dataset):
    """Wrap a dataset so each item yields two independent augmentations.

    Each __getitem__ applies the (stochastic) transform twice to the
    same underlying image and returns (view_1, view_2, label).
    """

    def __init__(self, base_ds, transform):
        self.base = base_ds
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        img, label = self.base[i]
        view_a = self.transform(img)
        view_b = self.transform(img)
        return view_a, view_b, label
|
|
|
|
| |
| |
| |
|
|
| |
# ---------------- Hyperparameters ----------------
NUM_CLASSES = 10              # CIFAR-10
OUTPUT_DIM = 256              # embedding dimension on the sphere
N_ANCHORS = 64                # constellation anchors
N_COMP = 8                    # patchwork components
D_COMP = 64                   # dimension per component
BATCH = 256
EPOCHS = 100
LR = 3e-4
ACTIVATION = 'squared_relu'

# Anchor self-distillation ("push") schedule.
PUSH_INTERVAL = 100           # batches between anchor pushes
PUSH_LR = 0.1                 # step size toward class centroids
PUSH_BUFFER_SIZE = 5000       # embeddings retained for centroid estimation

# Run banner.
print("=" * 60)
print("GeoLIP Core β Conv + ConstellationCore")
print(f" Encoder: 6-layer conv β {OUTPUT_DIM}-d sphere")
print(f" Constellation: {N_ANCHORS} anchors, {N_COMP}Γ{D_COMP} patchwork")
print(f" Activation: {ACTIVATION}")
print(f" Loss: CE + InfoNCE + attract + CV(0.22) + spread")
print(f" Batch: {BATCH}, LR: {LR}, Epochs: {EPOCHS}")
print(f" Push: every {PUSH_INTERVAL} batches, lr={PUSH_LR}")
print(f" Device: {DEVICE}")
print("=" * 60)
|
|
# Training-time augmentation: two independent draws of this pipeline
# produce the two views consumed by the InfoNCE objective.
aug_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
# Validation uses deterministic preprocessing only.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# The raw train set is left untransformed; TwoViewDataset applies the
# stochastic augmentation twice per sample.
raw_train = datasets.CIFAR10(root='./data', train=True, download=True)
train_ds = TwoViewDataset(raw_train, aug_transform)
val_ds = datasets.CIFAR10(root='./data', train=False,
                          download=True, transform=val_transform)

# drop_last keeps batch size constant (stable BatchNorm/InfoNCE stats).
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=2, pin_memory=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH, shuffle=False,
    num_workers=2, pin_memory=True)

print(f" Train: {len(train_ds):,} Val: {len(val_ds):,}")
|
|
| |
# Model, optimizer, and LR schedule (3-epoch linear warmup -> cosine decay).
model = GeoLIPCore(
    num_classes=NUM_CLASSES, output_dim=OUTPUT_DIM,
    n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP,
    activation=ACTIVATION,
).to(DEVICE)

# Parameter counts, split between the pixel encoder and geometric core.
n_params = sum(p.numel() for p in model.parameters())
n_enc = sum(p.numel() for p in model.encoder.parameters())
n_core = sum(p.numel() for p in model.core.parameters())
print(f" Parameters: {n_params:,} (encoder: {n_enc:,}, core: {n_core:,})")

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
total_steps = len(train_loader) * EPOCHS    # scheduler is stepped per batch
warmup_steps = len(train_loader) * 3        # 3 epochs of linear warmup
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    [torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=warmup_steps),
     torch.optim.lr_scheduler.CosineAnnealingLR(
         optimizer, T_max=max(total_steps - warmup_steps, 1), eta_min=1e-6)],
    milestones=[warmup_steps])

# NOTE(review): GradScaler targets fp16 overflow; the training loop uses
# bfloat16 autocast, where scaling is unnecessary (harmless) — confirm
# whether this is intentional.
scaler = torch.amp.GradScaler("cuda")
os.makedirs("checkpoints", exist_ok=True)
writer = SummaryWriter("runs/geolip_core")   # TensorBoard logging
best_acc = 0.0   # best validation accuracy seen so far
gs = 0           # global step (batches processed)
|
|
# Rolling buffers of recent training embeddings and labels, used to
# estimate class centroids for the periodic anchor push.
emb_buffer = None
lbl_buffer = None
push_count = 0   # number of anchor pushes performed so far

print(f"\n{'='*60}")
print(f"TRAINING β {EPOCHS} epochs")
print(f"{'='*60}")
|
|
for epoch in range(EPOCHS):
    # ---------------- Train ----------------
    model.train()
    t0 = time.time()
    tot_loss, tot_nce_acc, tot_nearest_cos, n = 0, 0, 0, 0
    correct, total = 0, 0

    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
    for v1, v2, targets in pbar:
        v1 = v1.to(DEVICE, non_blocking=True)
        v2 = v2.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        # Forward both augmented views under bf16 autocast; the loss
        # combines CE on view 1 with view1-vs-view2 contrastive and
        # geometric terms (see model.compute_loss).
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            out1 = model(v1)
            out2 = model(v2)
            loss, ld = model.compute_loss(out1, targets, output_aug=out2)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)                 # unscale before clipping
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()                           # per-batch LR schedule
        gs += 1

        # Maintain a rolling buffer of the most recent embeddings and
        # labels for centroid estimation (capped at PUSH_BUFFER_SIZE).
        with torch.no_grad():
            batch_emb = out1['embedding'].detach().float()
            if emb_buffer is None:
                emb_buffer = batch_emb
                lbl_buffer = targets.detach()
            else:
                emb_buffer = torch.cat([emb_buffer, batch_emb])[-PUSH_BUFFER_SIZE:]
                lbl_buffer = torch.cat([lbl_buffer, targets.detach()])[-PUSH_BUFFER_SIZE:]

        # Periodic anchor push: only once the buffer holds enough
        # samples (>500) to give a representative centroid estimate.
        if gs % PUSH_INTERVAL == 0 and emb_buffer is not None and emb_buffer.shape[0] > 500:
            moved = model.push_anchors_to_centroids(
                emb_buffer, lbl_buffer, lr=PUSH_LR)
            push_count += 1
            writer.add_scalar("step/anchors_moved", moved, gs)

        # Running training metrics.
        preds = out1['logits'].argmax(-1)
        correct += (preds == targets).sum().item()
        total += targets.shape[0]
        tot_loss += loss.item()
        tot_nce_acc += ld.get('nce_acc', 0)
        tot_nearest_cos += ld.get('nearest_cos', 0)
        n += 1

        if n % 10 == 0:
            # FIX: removed the stray `ordered=True` kwarg — tqdm's
            # set_postfix has no such parameter, so it was rendered as a
            # bogus "ordered=True" stat in the progress bar.
            pbar.set_postfix(
                loss=f"{tot_loss/n:.4f}",
                acc=f"{100*correct/total:.0f}%",
                nce=f"{tot_nce_acc/n:.2f}",
                cos=f"{ld.get('nearest_cos', 0):.3f}",
                push=push_count)

    elapsed = time.time() - t0
    train_acc = 100 * correct / total

    # ---------------- Validate ----------------
    model.eval()
    vc, vt_n = 0, 0
    all_embs = []
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        for imgs, lbls in val_loader:
            imgs = imgs.to(DEVICE)
            lbls = lbls.to(DEVICE)
            out = model(imgs)
            vc += (out['logits'].argmax(-1) == lbls).sum().item()
            vt_n += lbls.shape[0]
            all_embs.append(out['embedding'].float().cpu())

    val_acc = 100 * vc / vt_n

    # Geometric diagnostics on a 2000-sample subset of val embeddings.
    embs = torch.cat(all_embs)[:2000].to(DEVICE)
    with torch.no_grad():
        v_cv = GeometricOps.cv_metric(embs, n_samples=200)
        diag = GeometricOps.diagnostics(model.core.constellation, embs)

    writer.add_scalar("epoch/train_acc", train_acc, epoch + 1)
    writer.add_scalar("epoch/val_acc", val_acc, epoch + 1)
    writer.add_scalar("epoch/val_cv", v_cv, epoch + 1)
    writer.add_scalar("epoch/anchors", diag['n_active'], epoch + 1)
    writer.add_scalar("epoch/nearest_cos", tot_nearest_cos / n, epoch + 1)
    writer.add_scalar("epoch/push_count", push_count, epoch + 1)

    # Checkpoint whenever validation accuracy reaches a new best.
    mk = ""
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
        }, "checkpoints/geolip_core_best.pt")
        # FIX: the original marker literal was a mangled multi-byte glyph
        # split across two source lines (a syntax error); restored to a
        # single-line best-epoch marker.
        mk = " ★"

    nce_m = tot_nce_acc / n
    cos_m = tot_nearest_cos / n
    # NOTE(review): both band markers below render as the same mangled
    # glyph in the source — likely originally distinct (e.g. check/cross);
    # confirm against the upstream file.
    cv_band = "β" if 0.18 <= v_cv <= 0.25 else "β"
    print(f" E{epoch+1:3d}: train={train_acc:.1f}% val={val_acc:.1f}% "
          f"loss={tot_loss/n:.4f} nce={nce_m:.2f} cos={cos_m:.3f} "
          f"cv={v_cv:.4f}({cv_band}) anch={diag['n_active']}/{N_ANCHORS} "
          f"push={push_count} ({elapsed:.0f}s){mk}")
|
|
# Flush TensorBoard logs and print the final run summary.
writer.close()
print(f"\n Best val accuracy: {best_acc:.1f}%")
print(f" Parameters: {n_params:,}")
print(f"\n{'='*60}")
print("DONE")
print(f"{'='*60}")