"""
SVAE v2 Conduit Trainer — Prototype
=====================================
Train PatchSVAEv2 from random init on noise.
The decoder MUST reconstruct from decomposed spectral + conduit bundles.
No M_hat shortcut. Every conduit element is load-bearing.

Readouts per epoch:
    Standard: MSE, S profile, erank, s_delta, CV
    Conduit:  friction stats, settle distribution, char_coeff profile,
              per-mode reconstruction contribution

Should converge rapidly or fail — we'll know within 10 epochs.

Usage:
    python train_v2_conduit.py
"""

import os
import math
import time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# Best-effort Hugging Face login through Colab secrets. Any failure (for
# example google.colab not being importable) is ignored and the script
# continues without logging in.
try:
    from google.colab import userdata
    os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
    from huggingface_hub import login
    login(token=os.environ["HF_TOKEN"])
except Exception:
    pass

from geolip_svae.model import cv_of, extract_patches, stitch_patches
from geolip_svae.model_v2 import PatchSVAEv2


# ═══════════════════════════════════════════════════════════════
# CONFIG
# ═══════════════════════════════════════════════════════════════

HF_REPO = 'AbstractPhil/geolip-SVAE'
VERSION = 'version2_v2_conduit_proto_2'
LOCAL_DIR = f'/content/{VERSION}_checkpoints'
LOG_PATH = os.path.join(LOCAL_DIR, 'training_log.json')

CFG = dict(
    # Architecture (inherited from Fresnel v50)
    V=16, D=4, ps=4, hidden=384, depth=4, n_cross=2,
    stage_hidden=128, stage_V=64,
    # Training
    img_size=64, batch_size=256, lr=3e-4, epochs=50,
    ds_size=1280000, val_size=10000,
    # CV soft hand
    target_cv=0.2915, cv_weight=0.3, boost=0.5, sigma=0.15,
    # Checkpointing
    save_every=5, val_per_type_every=5,
)


# ═══════════════════════════════════════════════════════════════
# NOISE DATASET (16 types, same as Freckles)
# ═══════════════════════════════════════════════════════════════

NOISE_NAMES = {
    0: 'gaussian', 1: 'uniform', 2: 'uniform_sc', 3: 'poisson',
    4: 'pink', 5: 'brown', 6: 'salt_pepper', 7: 'sparse',
    8: 'block', 9: 'gradient', 10: 'checker', 11: 'mixed',
    12: 'structural', 13: 'cauchy', 14: 'exponential', 15: 'laplace',
}


def _pink(shape):
    """Return Gaussian noise of ``shape`` whose 2-D spectrum is divided by |f|.

    NOTE(review): the zero-frequency bin has fx = fy = 0, so the clamp makes
    its divisor 1e-8 and the DC term is scaled by 1e8 — confirm this is
    intended. Behaviour is left unchanged here.
    """
    w = torch.randn(shape)
    S = torch.fft.rfft2(w)
    h, ww = shape[-2], shape[-1]
    fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, ww // 2 + 1)
    fx = torch.fft.rfftfreq(ww).unsqueeze(0).expand(h, -1)
    return torch.fft.irfft2(
        S / torch.sqrt(fx**2 + fy**2).clamp(min=1e-8), s=(h, ww))


def _brown(shape):
    """Return Gaussian noise of ``shape`` whose 2-D spectrum is divided by |f|².

    NOTE(review): as in ``_pink``, the zero-frequency divisor is the clamp
    floor 1e-8 — confirm this is intended. Behaviour is left unchanged here.
    """
    w = torch.randn(shape)
    S = torch.fft.rfft2(w)
    h, ww = shape[-2], shape[-1]
    fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, ww // 2 + 1)
    fx = torch.fft.rfftfreq(ww).unsqueeze(0).expand(h, -1)
    return torch.fft.irfft2(
        S / (fx**2 + fy**2).clamp(min=1e-8), s=(h, ww))


def _gen_noise(noise_type, s, rng):
    """Generate one ``(3, s, s)`` noise image of the requested type.

    Args:
        noise_type: integer key into ``NOISE_NAMES`` (0-15). Any other value
            falls through to plain Gaussian noise.
        s: image side length in pixels.
        rng: ``np.random.RandomState`` used for the scalar parameters
            (Poisson rate, block size, gradient angle, checker size, mix
            weight). Per-pixel randomness comes from the global torch RNG.

    Returns:
        A tensor of shape ``(3, s, s)``; values are not clamped here.
    """
    if noise_type == 0:
        return torch.randn(3, s, s)
    elif noise_type == 1:
        return torch.rand(3, s, s) * 2 - 1
    elif noise_type == 2:
        return (torch.rand(3, s, s) - 0.5) * 4
    elif noise_type == 3:
        lam = rng.uniform(0.5, 20.0)
        return torch.poisson(torch.full((3, s, s), lam)) / lam - 1.0
    elif noise_type == 4:
        img = _pink((3, s, s))
        return img / (img.std() + 1e-8)
    elif noise_type == 5:
        img = _brown((3, s, s))
        return img / (img.std() + 1e-8)
    elif noise_type == 6:
        # ±2 salt-and-pepper with a small Gaussian jitter on top.
        return torch.where(
            torch.rand(3, s, s) > 0.5,
            torch.ones(3, s, s) * 2,
            -torch.ones(3, s, s) * 2) + torch.randn(3, s, s) * 0.1
    elif noise_type == 7:
        # Roughly 10% of pixels are non-zero.
        return torch.randn(3, s, s) * (torch.rand(3, s, s) > 0.9).float() * 3
    elif noise_type == 8:
        b = rng.randint(2, max(3, s // 2))
        sm = torch.randn(3, s // b + 1, s // b + 1)
        return F.interpolate(
            sm.unsqueeze(0), size=s, mode='nearest').squeeze(0)
    elif noise_type == 9:
        gy = torch.linspace(-2, 2, s).unsqueeze(1).expand(s, s)
        gx = torch.linspace(-2, 2, s).unsqueeze(0).expand(s, s)
        a = rng.uniform(0, 2 * math.pi)
        return (math.cos(a) * gx + math.sin(a) * gy).unsqueeze(0).expand(
            3, -1, -1) + torch.randn(3, s, s) * 0.5
    elif noise_type == 10:
        cs = rng.randint(2, max(3, s // 2))
        cy = torch.arange(s) // cs
        cx = torch.arange(s) // cs
        return ((cy.unsqueeze(1) + cx.unsqueeze(0)) % 2).float().unsqueeze(
            0).expand(3, -1, -1) * 2 - 1 + torch.randn(3, s, s) * 0.3
    elif noise_type == 11:
        alpha = rng.uniform(0.2, 0.8)
        return alpha * torch.randn(3, s, s) + (1 - alpha) * (
            torch.rand(3, s, s) * 2 - 1)
    elif noise_type == 12:
        # Four quadrants, each filled with a different noise family.
        img = torch.zeros(3, s, s)
        h2 = s // 2
        img[:, :h2, :h2] = torch.randn(3, h2, h2)
        img[:, :h2, h2:] = torch.rand(3, h2, h2) * 2 - 1
        img[:, h2:, :h2] = _pink((3, h2, h2)) / 2
        img[:, h2:, h2:] = torch.where(
            torch.rand(3, h2, h2) > 0.5,
            torch.ones(3, h2, h2),
            -torch.ones(3, h2, h2))
        return img
    elif noise_type == 13:
        return torch.tan(math.pi * (torch.rand(3, s, s) - 0.5)).clamp(-3, 3)
    elif noise_type == 14:
        return torch.empty(3, s, s).exponential_(1.0) - 1.0
    elif noise_type == 15:
        u = torch.rand(3, s, s) - 0.5
        return -torch.sign(u) * torch.log1p(-2 * u.abs())
    return torch.randn(3, s, s)


class OmegaNoiseDataset(torch.utils.data.Dataset):
    """Synthetic dataset that generates a fresh noise image on every access.

    The noise type is ``idx % 16``; the image content itself is random and
    does not depend on ``idx``.
    """

    def __init__(self, size=1280000, img_size=64):
        self.size = size
        self.img_size = img_size
        self._rng = np.random.RandomState(42)
        self._call_count = 0

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``(image, noise_type)`` with the image clamped to [-4, 4]."""
        self._call_count += 1
        # Every 1000th call, reseed both the NumPy and torch RNGs from the OS.
        if self._call_count % 1000 == 0:
            self._rng = np.random.RandomState(
                int.from_bytes(os.urandom(4), 'big'))
            torch.manual_seed(int.from_bytes(os.urandom(4), 'big'))
        noise_type = idx % 16
        img = _gen_noise(noise_type, self.img_size, self._rng).clamp(-4, 4)
        return img.float(), noise_type


def eval_per_type(model, img_size, device, n_per=32):
    """Return ``{noise_type: reconstruction MSE}`` over ``n_per`` fresh images each.

    Puts ``model`` into eval mode and leaves it there. Only the NumPy RNG is
    seeded (99); the torch draws inside ``_gen_noise`` use the global RNG.
    """
    rng = np.random.RandomState(99)
    model.eval()
    results = {}
    with torch.no_grad():
        for t in range(16):
            imgs = torch.stack([
                _gen_noise(t, img_size, rng).clamp(-4, 4)
                for _ in range(n_per)
            ]).to(device)
            out = model(imgs)
            results[t] = F.mse_loss(out['recon'], imgs).item()
    return results


# ═══════════════════════════════════════════════════════════════
# CONDUIT READOUTS
# ═══════════════════════════════════════════════════════════════

def conduit_readout(model, images, device):
    """Extract and summarize conduit telemetry from a batch.

    Runs ``model`` on ``images`` in eval mode (and leaves it in eval mode),
    then reads ``model.last_conduit_packet``.

    Returns:
        A dict of Python floats / lists, suitable for JSON logging.
    """
    model.eval()
    with torch.no_grad():
        out = model(images.to(device))
        packet = model.last_conduit_packet

        S = out['svd']['S_orig']  # (B, N, D)
        B, N, D = S.shape
        friction = packet.friction.reshape(B, N, D)
        settle = packet.settle.reshape(B, N, D)
        char_coeffs = packet.char_coeffs.reshape(B, N, D)
        # Not summarized below; the reshape still fails loudly on a size
        # mismatch.
        ext_order = packet.extraction_order.reshape(B, N, D)
        refine_res = packet.refinement_residual.reshape(B, N)

        # Log-friction for readable stats
        log_fric = torch.log1p(friction)

        stats = {
            'S_mean': S.mean(dim=(0, 1)).cpu().tolist(),
            'S_std': S.std(dim=(0, 1)).cpu().tolist(),
            'friction_mean': friction.mean().item(),
            'friction_max': friction.max().item(),
            'friction_std': friction.std().item(),
            'log_fric_mean': log_fric.mean(dim=(0, 1)).cpu().tolist(),
            'log_fric_std': log_fric.std(dim=(0, 1)).cpu().tolist(),
            'settle_mean': settle.mean(dim=(0, 1)).cpu().tolist(),
            'settle_frac_gt2': (settle > 2).float().mean().item(),
            'char_coeffs_mean': char_coeffs.mean(dim=(0, 1)).cpu().tolist(),
            'refine_res_mean': refine_res.mean().item(),
            'refine_res_max': refine_res.max().item(),
        }

        # Per-mode friction spatial CV: std/mean over the N axis per image,
        # then averaged over the batch.
        for d in range(D):
            per_img = friction[:, :, d].reshape(B, -1)
            cvs = per_img.std(dim=1) / (per_img.mean(dim=1) + 1e-8)
            stats[f'friction_spatial_cv_mode{d}'] = cvs.mean().item()

    return stats


def print_conduit_readout(stats, D=4):
    """Pretty-print conduit telemetry.

    Args:
        stats: dict produced by ``conduit_readout``.
        D: number of ``friction_spatial_cv_mode{d}`` entries to print;
            missing entries are shown as 0.
    """
    print(f" S: [{', '.join(f'{v:.3f}' for v in stats['S_mean'])}]")
    print(f" S_std: [{', '.join(f'{v:.4f}' for v in stats['S_std'])}]")
    print(f" log_fric: [{', '.join(f'{v:.3f}' for v in stats['log_fric_mean'])}] "
          f"± [{', '.join(f'{v:.3f}' for v in stats['log_fric_std'])}]")
    print(f" fric_raw: mean={stats['friction_mean']:.1f} max={stats['friction_max']:.0f}")
    print(f" settle: [{', '.join(f'{v:.2f}' for v in stats['settle_mean'])}] "
          f"(>{2}: {stats['settle_frac_gt2']:.1%})")
    print(f" char_c: [{', '.join(f'{v:.4f}' for v in stats['char_coeffs_mean'])}]")
    print(f" refine: mean={stats['refine_res_mean']:.2e} max={stats['refine_res_max']:.2e}")
    spatial_cvs = [stats.get(f'friction_spatial_cv_mode{d}', 0) for d in range(D)]
    print(f" fric_cv: [{', '.join(f'{v:.4f}' for v in spatial_cvs)}]")


# ═══════════════════════════════════════════════════════════════
# LOGGING
# ═══════════════════════════════════════════════════════════════

def load_log():
    """Load the JSON training log from ``LOG_PATH``, or start an empty one."""
    if os.path.exists(LOG_PATH):
        with open(LOG_PATH) as f:
            return json.load(f)
    return {'version': VERSION, 'entries': []}


def save_log(log):
    """Write ``log`` to ``LOG_PATH`` as indented JSON, creating the directory."""
    os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True)
    with open(LOG_PATH, 'w') as f:
        json.dump(log, f, indent=2)


# ═══════════════════════════════════════════════════════════════
# SAVE & PUSH
# ═══════════════════════════════════════════════════════════════

def save_checkpoint(model, opt, sched, epoch, val_mse, log, path, is_best=False):
    """Save a checkpoint and the training log locally, then push both to the Hub.

    Args:
        model, opt, sched: objects whose ``state_dict()`` is stored.
        epoch: epoch number stored in the checkpoint and commit messages.
        val_mse: validation MSE stored in the checkpoint and commit messages.
        log: training log dict, written to ``LOG_PATH``.
        path: local file path for the checkpoint.
        is_best: if True, the checkpoint is also published as ``best.pt``.

    The local checkpoint write is not guarded and may raise. The log write and
    the Hub push are best-effort: failures are printed and swallowed.
    """
    os.makedirs(os.path.dirname(path), exist_ok=True)
    ckpt = {
        'config': {
            'V': CFG['V'], 'D': CFG['D'], 'patch_size': CFG['ps'],
            'hidden': CFG['hidden'], 'depth': CFG['depth'],
            'n_cross_layers': CFG['n_cross'],
            'stage_hidden': CFG.get('stage_hidden', 128),
            'stage_V': CFG.get('stage_V', 16),
            'img_size': CFG['img_size'],
            'model_type': 'v2',
        },
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'scheduler_state_dict': sched.state_dict(),
        'epoch': epoch,
        'val_mse': val_mse,
    }
    torch.save(ckpt, path)
    size_mb = os.path.getsize(path) / (1024 * 1024)
    print(f" 💾 {path} ({size_mb:.1f}MB, ep{epoch}, MSE={val_mse:.6f})")

    # Persist the log locally BEFORE attempting the push. Previously this
    # happened inside the push `try`, after the uploads, so a failed or
    # unavailable Hub connection meant the log was never written to disk.
    try:
        save_log(log)
    except (OSError, TypeError, ValueError) as e:
        print(f" ⚠️ Log save failed: {e}")

    try:
        from huggingface_hub import HfApi
        api = HfApi()

        # Map remote path -> commit message. Keying by remote path avoids
        # uploading the same file twice when `path` is already 'best.pt'
        # and is_best is set; the BEST message wins in that case.
        uploads = {
            f'{VERSION}/checkpoints/{os.path.basename(path)}':
                f'{VERSION} ep{epoch} mse={val_mse:.6f}',
        }
        if is_best:
            uploads[f'{VERSION}/checkpoints/best.pt'] = (
                f'{VERSION} BEST ep{epoch} mse={val_mse:.6f}')

        for repo_path, message in uploads.items():
            api.upload_file(
                path_or_fileobj=path,
                path_in_repo=repo_path,
                repo_id=HF_REPO, repo_type='model',
                commit_message=message)

        api.upload_file(
            path_or_fileobj=LOG_PATH,
            path_in_repo=f'{VERSION}/training_log.json',
            repo_id=HF_REPO, repo_type='model',
            commit_message=f'{VERSION} log ep{epoch}')
        print(f" ☁️ Pushed ep{epoch}")
    except Exception as e:
        # Deliberate best-effort: training must not stop because a push failed.
        print(f" ⚠️ Push failed: {e}")


# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════

def _mean_recon_mse(model, loader, device, max_samples=None):
    """Return the sample-weighted mean reconstruction MSE over ``loader``.

    Puts ``model`` into eval mode. If ``max_samples`` is given, iteration
    stops after the first batch that brings the sample count to at least
    that value.
    """
    model.eval()
    total, count = 0, 0
    with torch.no_grad():
        for imgs, _ in loader:
            imgs = imgs.to(device)
            out = model(imgs)
            total += F.mse_loss(out['recon'], imgs).item() * len(imgs)
            count += len(imgs)
            if max_samples is not None and count >= max_samples:
                break
    return total / count


def train():
    """Train a fresh ``PatchSVAEv2`` on synthetic noise and return the model.

    Each epoch: one pass over the training loader, full validation, a
    geometry snapshot and conduit readout on 64 validation images, optional
    per-noise-type evaluation, a log entry, and checkpointing (on a new best
    validation MSE and every ``CFG['save_every']`` epochs).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs(LOCAL_DIR, exist_ok=True)

    print("\n" + "=" * 70)
    print(f"SVAE v2 CONDUIT TRAINER — {VERSION}")
    print("=" * 70)

    # ── Fresh v2 model from random init ──
    D = CFG['D']
    model = PatchSVAEv2(
        V=CFG['V'], D=D, ps=CFG['ps'],
        hidden=CFG['hidden'], depth=CFG['depth'],
        n_cross=CFG['n_cross'],
        stage_hidden=CFG.get('stage_hidden', 128),
        stage_V=CFG.get('stage_V', 16),
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters())
    print(f"\n Fresh PatchSVAEv2 from random init")
    print(f" Total params: {n_params:,}")

    # ── Data ──
    print(f"\n Dataset: 16 noise types, {CFG['ds_size']:,} samples/epoch")
    print(f" Image size: {CFG['img_size']}×{CFG['img_size']}")
    print(f" Batch size: {CFG['batch_size']}")

    train_ds = OmegaNoiseDataset(size=CFG['ds_size'], img_size=CFG['img_size'])
    val_ds = OmegaNoiseDataset(size=CFG['val_size'], img_size=CFG['img_size'])
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=CFG['batch_size'], shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=CFG['batch_size'], shuffle=False,
        num_workers=4, pin_memory=True)

    # ── Optimizer (all params — full model training) ──
    opt = torch.optim.Adam(model.parameters(), lr=CFG['lr'])
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=CFG['epochs'])

    # CV soft hand
    target_cv = CFG['target_cv']
    cv_weight = CFG['cv_weight']
    boost = CFG['boost']
    sigma = CFG['sigma']

    # Log
    log = load_log()
    best_mse = float('inf')

    # ── Initial conduit readout ──
    print(f"\n Initial conduit profile:")
    sample_batch = next(iter(val_loader))[0][:64]
    init_stats = conduit_readout(model, sample_batch, device)
    print_conduit_readout(init_stats, D)

    # ── Initial MSE (will be terrible — decoder is random) ──
    init_mse = _mean_recon_mse(model, val_loader, device, max_samples=2560)
    print(f"\n Initial MSE (random decoder): {init_mse:.4f}")
    print("=" * 70)

    # ═══════════════════════════════════════════════════════════
    # TRAINING LOOP
    # ═══════════════════════════════════════════════════════════

    for epoch in range(1, CFG['epochs'] + 1):
        model.train()
        total_loss, total_recon, n = 0, 0, 0
        last_cv = target_cv
        t0 = time.time()

        pbar = tqdm(train_loader, desc=f"Ep {epoch}/{CFG['epochs']}",
                    bar_format='{l_bar}{bar:20}{r_bar}')
        for batch_idx, (images, _) in enumerate(pbar):
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)

            # CV soft hand: the reconstruction weight is boosted by a
            # Gaussian proximity of the measured CV to the target.
            # NOTE(review): last_cv is measured under no_grad, so the cv_pen
            # term below carries no gradient and only shifts the reported
            # loss value — confirm this is intended.
            with torch.no_grad():
                if batch_idx % 50 == 0:
                    cur_cv = cv_of(out['svd']['M'][0, 0])
                    if cur_cv > 0:
                        last_cv = cur_cv
                delta = last_cv - target_cv
                prox = math.exp(-delta**2 / (2 * sigma**2))
                recon_w = 1.0 + boost * prox
                cv_pen = cv_weight * (1.0 - prox)

            loss = recon_w * recon_loss + cv_pen * (last_cv - target_cv)**2
            loss.backward()
            # Only the cross-attention parameters are gradient-clipped.
            torch.nn.utils.clip_grad_norm_(
                model.cross_attn.parameters(), max_norm=0.5)
            opt.step()

            total_loss += loss.item() * len(images)
            total_recon += recon_loss.item() * len(images)
            n += len(images)
            pbar.set_postfix_str(f"mse={recon_loss.item():.4f} cv={last_cv:.3f}")

        # Capture the learning rate actually used this epoch; reading it
        # after sched.step() would record the next epoch's rate instead.
        epoch_lr = opt.param_groups[0]['lr']
        sched.step()
        epoch_time = time.time() - t0

        # ── Validation ──
        val_mse = _mean_recon_mse(model, val_loader, device)

        # ── Geometry snapshot ──
        with torch.no_grad():
            sample = next(iter(val_loader))[0][:64].to(device)
            out = model(sample)
            S_mean = out['svd']['S_orig'].mean(dim=(0, 1))
            erank = model.effective_rank(
                out['svd']['S'].reshape(-1, D)).mean().item()
            s_delta = model.s_delta(out['svd']['S_orig'], out['svd']['S'])

        # ── Conduit readout ──
        cond_stats = conduit_readout(model, sample.cpu(), device)

        # ── Print ──
        is_best = val_mse < best_mse
        if is_best:
            best_mse = val_mse

        print(f"\n ep{epoch:3d} | recon={total_recon/n:.4f} val={val_mse:.4f} "
              f"{'★ BEST' if is_best else ''} | "
              f"er={erank:.2f} Sd={s_delta:.4f} cv={last_cv:.3f} | {epoch_time:.0f}s")
        print_conduit_readout(cond_stats, D)

        # ── Per-type eval ──
        if epoch % CFG['val_per_type_every'] == 0 or epoch <= 3:
            type_mse = eval_per_type(model, CFG['img_size'], device)
            type_str = " ".join(
                f"{NOISE_NAMES[t][:4]}={v:.3f}"
                for t, v in sorted(type_mse.items()))
            print(f" types: {type_str}")

        # ── Log entry ──
        log['entries'].append({
            'epoch': epoch,
            'train_mse': total_recon / n,
            'val_mse': val_mse,
            'cv': last_cv,
            'erank': erank,
            's_delta': s_delta,
            'S_mean': S_mean.cpu().tolist(),
            'conduit': cond_stats,
            'epoch_time': epoch_time,
            'lr': epoch_lr,
        })

        # ── Checkpoint ──
        if is_best:
            save_checkpoint(model, opt, sched, epoch, val_mse, log,
                            os.path.join(LOCAL_DIR, 'best.pt'), is_best=True)
        if epoch % CFG['save_every'] == 0:
            save_checkpoint(model, opt, sched, epoch, val_mse, log,
                            os.path.join(LOCAL_DIR, f'epoch_{epoch:04d}.pt'))

    # ═══════════════════════════════════════════════════════════
    # DONE
    # ═══════════════════════════════════════════════════════════

    print(f"\n{'=' * 70}")
    print(f"v2 CONDUIT TRAINING COMPLETE — {VERSION}")
    print(f" Best MSE: {best_mse:.6f}")
    print(f" Epochs: {CFG['epochs']}")
    print(f" Params: {n_params:,}")
    print(f"{'=' * 70}")

    return model


if __name__ == "__main__":
    torch.set_float32_matmul_precision('high')
    train()