| """ |
| SVAE v2 Conduit Trainer β Prototype |
| ===================================== |
| Train PatchSVAEv2 from random init on noise. |
| The decoder MUST reconstruct from decomposed spectral + conduit bundles. |
| No M_hat shortcut. Every conduit element is load-bearing. |
| |
| Readouts per epoch: |
| Standard: MSE, S profile, erank, s_delta, CV |
| Conduit: friction stats, settle distribution, char_coeff profile, |
| per-mode reconstruction contribution |
| |
| Should converge rapidly or fail β we'll know within 10 epochs. |
| |
| Usage: |
| python train_v2_conduit.py |
| """ |
|
|
| import os |
| import math |
| import time |
| import json |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from tqdm import tqdm |
|
|
# Best-effort Hugging Face auth when running inside Colab. Any failure
# (not on Colab, missing secret, no huggingface_hub) is swallowed so the
# script still runs locally without credentials.
try:
    from google.colab import userdata

    _hf_token = userdata.get('HF_TOKEN')
    os.environ["HF_TOKEN"] = _hf_token

    from huggingface_hub import login
    login(token=_hf_token)
except Exception:
    pass
|
|
| from geolip_svae.model import cv_of, extract_patches, stitch_patches |
| from geolip_svae.model_v2 import PatchSVAEv2 |
|
|
|
|
| |
| |
| |
|
|
# Run identity and output locations. LOCAL_DIR assumes a Colab runtime
# (/content); LOG_PATH is the JSON training log uploaded alongside
# checkpoints.
HF_REPO = 'AbstractPhil/geolip-SVAE'
VERSION = 'version2_v2_conduit_proto_2'
LOCAL_DIR = f'/content/{VERSION}_checkpoints'
LOG_PATH = os.path.join(LOCAL_DIR, 'training_log.json')
|
|
# Hyper-parameters for a single prototype run.
CFG = dict(
    # -- model architecture (forwarded to PatchSVAEv2) --
    V=16, D=4, ps=4, hidden=384, depth=4, n_cross=2,
    stage_hidden=128, stage_V=64,

    # -- data / optimization --
    img_size=64,
    batch_size=256,
    lr=3e-4,
    epochs=50,
    ds_size=1280000,     # samples per "epoch" of the synthetic dataset
    val_size=10000,

    # -- CV-targeting loss shaping (see train()) --
    target_cv=0.2915,    # desired coefficient of variation of M
    cv_weight=0.3,       # penalty weight when far from target_cv
    boost=0.5,           # extra recon weight when near target_cv
    sigma=0.15,          # width of the proximity gaussian

    # -- cadence (in epochs) --
    save_every=5,
    val_per_type_every=5,
)
|
|
|
|
| |
| |
| |
|
|
# Human-readable labels for the 16 procedural noise generators in
# _gen_noise (key == noise_type index).
NOISE_NAMES = {
    0: 'gaussian', 1: 'uniform', 2: 'uniform_sc', 3: 'poisson',
    4: 'pink', 5: 'brown', 6: 'salt_pepper', 7: 'sparse',
    8: 'block', 9: 'gradient', 10: 'checker', 11: 'mixed',
    12: 'structural', 13: 'cauchy', 14: 'exponential', 15: 'laplace',
}
|
|
|
|
| def _pink(shape): |
| w = torch.randn(shape) |
| S = torch.fft.rfft2(w) |
| h, ww = shape[-2], shape[-1] |
| fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, ww // 2 + 1) |
| fx = torch.fft.rfftfreq(ww).unsqueeze(0).expand(h, -1) |
| return torch.fft.irfft2(S / torch.sqrt(fx**2 + fy**2).clamp(min=1e-8), s=(h, ww)) |
|
|
|
|
| def _brown(shape): |
| w = torch.randn(shape) |
| S = torch.fft.rfft2(w) |
| h, ww = shape[-2], shape[-1] |
| fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, ww // 2 + 1) |
| fx = torch.fft.rfftfreq(ww).unsqueeze(0).expand(h, -1) |
| return torch.fft.irfft2(S / (fx**2 + fy**2).clamp(min=1e-8), s=(h, ww)) |
|
|
|
|
def _gen_noise(noise_type, s, rng):
    """Generate one (3, s, s) noise image of type *noise_type* (see NOISE_NAMES).

    rng is a numpy RandomState used only for scalar branch parameters;
    all per-pixel randomness comes from torch's global RNG. The branch
    order and the exact sequence of rng/torch draws are load-bearing for
    reproducibility — do not reorder.
    """
    if noise_type == 0: return torch.randn(3, s, s)  # gaussian
    elif noise_type == 1: return torch.rand(3, s, s) * 2 - 1  # uniform in [-1, 1]
    elif noise_type == 2: return (torch.rand(3, s, s) - 0.5) * 4  # uniform, wider [-2, 2]
    elif noise_type == 3:
        # Poisson counts at a random rate, shifted/scaled toward zero mean.
        lam = rng.uniform(0.5, 20.0)
        return torch.poisson(torch.full((3, s, s), lam)) / lam - 1.0
    elif noise_type == 4:
        # 1/f spectrum, normalized to unit std.
        img = _pink((3, s, s)); return img / (img.std() + 1e-8)
    elif noise_type == 5:
        # 1/f^2 spectrum, normalized to unit std.
        img = _brown((3, s, s)); return img / (img.std() + 1e-8)
    elif noise_type == 6:
        # Binary ±2 salt-and-pepper field with light gaussian jitter.
        return torch.where(torch.rand(3, s, s) > 0.5,
                           torch.ones(3, s, s) * 2, -torch.ones(3, s, s) * 2) + torch.randn(3, s, s) * 0.1
    elif noise_type == 7:
        # ~10%-dense sparse gaussian spikes, amplitude-boosted.
        return torch.randn(3, s, s) * (torch.rand(3, s, s) > 0.9).float() * 3
    elif noise_type == 8:
        # Piecewise-constant blocks: coarse gaussian grid, nearest upsample.
        b = rng.randint(2, max(3, s // 2))
        sm = torch.randn(3, s // b + 1, s // b + 1)
        return F.interpolate(sm.unsqueeze(0), size=s, mode='nearest').squeeze(0)
    elif noise_type == 9:
        # Linear gradient at a random angle plus gaussian noise.
        gy = torch.linspace(-2, 2, s).unsqueeze(1).expand(s, s)
        gx = torch.linspace(-2, 2, s).unsqueeze(0).expand(s, s)
        a = rng.uniform(0, 2 * math.pi)
        return (math.cos(a) * gx + math.sin(a) * gy).unsqueeze(0).expand(3, -1, -1) + torch.randn(3, s, s) * 0.5
    elif noise_type == 10:
        # Checkerboard with random cell size plus gaussian noise.
        cs = rng.randint(2, max(3, s // 2))
        cy = torch.arange(s) // cs; cx = torch.arange(s) // cs
        return ((cy.unsqueeze(1) + cx.unsqueeze(0)) % 2).float().unsqueeze(0).expand(3, -1, -1) * 2 - 1 + torch.randn(3, s, s) * 0.3
    elif noise_type == 11:
        # Random convex blend of gaussian and uniform noise.
        alpha = rng.uniform(0.2, 0.8)
        return alpha * torch.randn(3, s, s) + (1 - alpha) * (torch.rand(3, s, s) * 2 - 1)
    elif noise_type == 12:
        # Structural composite: four quadrants with different statistics.
        img = torch.zeros(3, s, s); h2 = s // 2
        img[:, :h2, :h2] = torch.randn(3, h2, h2)
        img[:, :h2, h2:] = torch.rand(3, h2, h2) * 2 - 1
        img[:, h2:, :h2] = _pink((3, h2, h2)) / 2
        img[:, h2:, h2:] = torch.where(torch.rand(3, h2, h2) > 0.5,
                                       torch.ones(3, h2, h2), -torch.ones(3, h2, h2))
        return img
    elif noise_type == 13:
        # Standard Cauchy via inverse CDF (tan of uniform), clamped for stability.
        return torch.tan(math.pi * (torch.rand(3, s, s) - 0.5)).clamp(-3, 3)
    elif noise_type == 14:
        # Exponential(1), shifted to zero mean.
        return torch.empty(3, s, s).exponential_(1.0) - 1.0
    elif noise_type == 15:
        # Laplace via inverse CDF of a centered uniform.
        u = torch.rand(3, s, s) - 0.5
        return -torch.sign(u) * torch.log1p(-2 * u.abs())
    return torch.randn(3, s, s)  # fallback (unreachable for types 0-15)
|
|
|
|
class OmegaNoiseDataset(torch.utils.data.Dataset):
    """Synthetic dataset cycling through the 16 noise types.

    The index determines the noise type (idx % 16), so sequential
    sampling yields a balanced type mix. Every 1000 __getitem__ calls
    both the numpy RNG and torch's *global* RNG are reseeded from
    os.urandom, making samples intentionally non-reproducible across
    runs. Each DataLoader worker holds its own copy of this object, so
    reseeding and _call_count are per-worker.
    """
    def __init__(self, size=1280000, img_size=64):
        self.size = size          # nominal epoch length
        self.img_size = img_size  # square image side
        self._rng = np.random.RandomState(42)  # per-worker scalar-param RNG
        self._call_count = 0      # __getitem__ calls seen by this worker
    def __len__(self):
        return self.size
    def __getitem__(self, idx):
        self._call_count += 1
        # Periodic hard reseed. NOTE(review): torch.manual_seed mutates
        # the global torch RNG from inside the dataset — presumably
        # deliberate for "omega" non-reproducibility; confirm intent.
        if self._call_count % 1000 == 0:
            self._rng = np.random.RandomState(int.from_bytes(os.urandom(4), 'big'))
            torch.manual_seed(int.from_bytes(os.urandom(4), 'big'))
        noise_type = idx % 16
        img = _gen_noise(noise_type, self.img_size, self._rng).clamp(-4, 4)
        return img.float(), noise_type
|
|
|
|
def eval_per_type(model, img_size, device, n_per=32):
    """Validation MSE per noise type, on a fixed-seed batch of n_per images.

    Returns {noise_type: mse} for all 16 types. The RandomState(99) seed
    keeps the scalar branch parameters comparable across epochs.
    """
    rng = np.random.RandomState(99)
    model.eval()
    results = {}
    with torch.no_grad():
        for noise_type in range(16):
            batch = [_gen_noise(noise_type, img_size, rng).clamp(-4, 4)
                     for _ in range(n_per)]
            batch = torch.stack(batch).to(device)
            recon = model(batch)['recon']
            results[noise_type] = F.mse_loss(recon, batch).item()
    return results
|
|
|
|
| |
| |
| |
|
|
def conduit_readout(model, images, device):
    """Extract and summarize conduit telemetry from a batch.

    Runs one forward pass, then reads model.last_conduit_packet and
    reduces its per-element tensors into scalars / per-mode lists that
    are JSON-serializable for logging.
    """
    model.eval()
    with torch.no_grad():
        out = model(images.to(device))
        packet = model.last_conduit_packet

    # Assumed (B, n_patches, D) per-patch singular values —
    # TODO(review): confirm against PatchSVAEv2's output contract.
    S = out['svd']['S_orig']
    B, N, D = S.shape

    # Packet tensors arrive flat; reshape to the (B, N, D) mode layout.
    friction = packet.friction.reshape(B, N, D)
    settle = packet.settle.reshape(B, N, D)
    char_coeffs = packet.char_coeffs.reshape(B, N, D)
    ext_order = packet.extraction_order.reshape(B, N, D)  # NOTE(review): unused below
    refine_res = packet.refinement_residual.reshape(B, N)

    # Friction spans orders of magnitude; summarize on a log1p scale.
    log_fric = torch.log1p(friction)

    stats = {
        'S_mean': S.mean(dim=(0, 1)).cpu().tolist(),
        'S_std': S.std(dim=(0, 1)).cpu().tolist(),
        'friction_mean': friction.mean().item(),
        'friction_max': friction.max().item(),
        'friction_std': friction.std().item(),
        'log_fric_mean': log_fric.mean(dim=(0, 1)).cpu().tolist(),
        'log_fric_std': log_fric.std(dim=(0, 1)).cpu().tolist(),
        'settle_mean': settle.mean(dim=(0, 1)).cpu().tolist(),
        'settle_frac_gt2': (settle > 2).float().mean().item(),
        'char_coeffs_mean': char_coeffs.mean(dim=(0, 1)).cpu().tolist(),
        'refine_res_mean': refine_res.mean().item(),
        'refine_res_max': refine_res.max().item(),
    }

    # Per-mode spatial CV of friction: how unevenly friction spreads
    # across patches within an image, averaged over the batch.
    for d in range(D):
        per_img = friction[:, :, d].reshape(B, -1)
        cvs = per_img.std(dim=1) / (per_img.mean(dim=1) + 1e-8)
        stats[f'friction_spatial_cv_mode{d}'] = cvs.mean().item()

    return stats
|
|
|
|
def print_conduit_readout(stats, D=4):
    """Pretty-print conduit telemetry."""
    def row(key, spec):
        # Comma-join a per-mode list from stats, formatting each entry.
        return ', '.join(format(v, spec) for v in stats[key])

    print(f" S: [{row('S_mean', '.3f')}]")
    print(f" S_std: [{row('S_std', '.4f')}]")
    print(f" log_fric: [{row('log_fric_mean', '.3f')}] "
          f"Β± [{row('log_fric_std', '.3f')}]")
    print(f" fric_raw: mean={stats['friction_mean']:.1f} max={stats['friction_max']:.0f}")
    print(f" settle: [{row('settle_mean', '.2f')}] "
          f"(>2: {stats['settle_frac_gt2']:.1%})")
    print(f" char_c: [{row('char_coeffs_mean', '.4f')}]")
    print(f" refine: mean={stats['refine_res_mean']:.2e} max={stats['refine_res_max']:.2e}")
    cv_row = ', '.join(format(stats.get(f'friction_spatial_cv_mode{d}', 0), '.4f')
                       for d in range(D))
    print(f" fric_cv: [{cv_row}]")
|
|
|
|
| |
| |
| |
|
|
def load_log():
    """Load the persisted training log from LOG_PATH, or start a fresh one."""
    if not os.path.exists(LOG_PATH):
        return {'version': VERSION, 'entries': []}
    with open(LOG_PATH) as fh:
        return json.load(fh)
|
|
|
|
def save_log(log):
    """Write the training log dict to LOG_PATH, creating parent dirs."""
    log_dir = os.path.dirname(LOG_PATH)
    os.makedirs(log_dir, exist_ok=True)
    with open(LOG_PATH, 'w') as fh:
        json.dump(log, fh, indent=2)
|
|
|
|
| |
| |
| |
|
|
def save_checkpoint(model, opt, sched, epoch, val_mse, log,
                    path, is_best=False):
    """Serialize model/optimizer/scheduler state to *path* and push to HF Hub.

    The checkpoint embeds the architecture config so the model can be
    rebuilt without this script. The Hub upload is best-effort: any
    failure is printed and swallowed so training continues offline.
    """
    os.makedirs(os.path.dirname(path), exist_ok=True)
    ckpt = {
        'config': {
            'V': CFG['V'], 'D': CFG['D'], 'patch_size': CFG['ps'],
            'hidden': CFG['hidden'], 'depth': CFG['depth'],
            'n_cross_layers': CFG['n_cross'],
            'stage_hidden': CFG.get('stage_hidden', 128),
            # NOTE(review): fallback 16 disagrees with CFG's stage_V=64;
            # harmless while CFG defines the key, but confirm intent.
            'stage_V': CFG.get('stage_V', 16),
            'img_size': CFG['img_size'],
            'model_type': 'v2',
        },
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'scheduler_state_dict': sched.state_dict(),
        'epoch': epoch,
        'val_mse': val_mse,
    }
    torch.save(ckpt, path)
    size_mb = os.path.getsize(path) / (1024 * 1024)
    print(f" πΎ {path} ({size_mb:.1f}MB, ep{epoch}, MSE={val_mse:.6f})")

    # Best-effort Hub sync: the checkpoint, an optional best.pt alias,
    # and the training log (saved to disk first via save_log).
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        api.upload_file(
            path_or_fileobj=path,
            path_in_repo=f'{VERSION}/checkpoints/{os.path.basename(path)}',
            repo_id=HF_REPO, repo_type='model',
            commit_message=f'{VERSION} ep{epoch} mse={val_mse:.6f}')
        if is_best:
            api.upload_file(
                path_or_fileobj=path,
                path_in_repo=f'{VERSION}/checkpoints/best.pt',
                repo_id=HF_REPO, repo_type='model',
                commit_message=f'{VERSION} BEST ep{epoch} mse={val_mse:.6f}')
        save_log(log)
        api.upload_file(
            path_or_fileobj=LOG_PATH,
            path_in_repo=f'{VERSION}/training_log.json',
            repo_id=HF_REPO, repo_type='model',
            commit_message=f'{VERSION} log ep{epoch}')
        print(f" βοΈ Pushed ep{epoch}")
    except Exception as e:
        print(f" β οΈ Push failed: {e}")
|
|
|
|
| |
| |
| |
|
|
def train():
    """Full training loop: build model, loaders, optimizer, run CFG['epochs'].

    Per epoch: train with a CV-proximity-shaped reconstruction loss,
    validate MSE, print spectral + conduit readouts, log to JSON, and
    checkpoint (best + periodic). Returns the trained model.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs(LOCAL_DIR, exist_ok=True)

    print("\n" + "=" * 70)
    print(f"SVAE v2 CONDUIT TRAINER β {VERSION}")
    print("=" * 70)

    # ---- model ----
    D = CFG['D']
    model = PatchSVAEv2(
        V=CFG['V'], D=D, ps=CFG['ps'],
        hidden=CFG['hidden'], depth=CFG['depth'],
        n_cross=CFG['n_cross'],
        stage_hidden=CFG.get('stage_hidden', 128),
        # NOTE(review): fallback 16 vs CFG's stage_V=64 — unused while
        # CFG defines the key, but the defaults should agree.
        stage_V=CFG.get('stage_V', 16),
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters())
    print(f"\n Fresh PatchSVAEv2 from random init")
    print(f" Total params: {n_params:,}")

    # ---- data ----
    print(f"\n Dataset: 16 noise types, {CFG['ds_size']:,} samples/epoch")
    print(f" Image size: {CFG['img_size']}Γ{CFG['img_size']}")
    print(f" Batch size: {CFG['batch_size']}")

    train_ds = OmegaNoiseDataset(size=CFG['ds_size'], img_size=CFG['img_size'])
    val_ds = OmegaNoiseDataset(size=CFG['val_size'], img_size=CFG['img_size'])
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=CFG['batch_size'], shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=CFG['batch_size'], shuffle=False,
        num_workers=4, pin_memory=True)

    # ---- optimizer / schedule ----
    opt = torch.optim.Adam(model.parameters(), lr=CFG['lr'])
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=CFG['epochs'])

    # CV-targeting loss-shaping hyper-parameters (see inner loop).
    target_cv = CFG['target_cv']
    cv_weight = CFG['cv_weight']
    boost = CFG['boost']
    sigma = CFG['sigma']

    log = load_log()
    best_mse = float('inf')

    # ---- pre-training readouts (random-init baseline) ----
    print(f"\n Initial conduit profile:")
    sample_batch = next(iter(val_loader))[0][:64]
    init_stats = conduit_readout(model, sample_batch, device)
    print_conduit_readout(init_stats, D)

    # Baseline MSE over the first ~2560 validation images.
    model.eval()
    init_mse = 0
    init_n = 0
    with torch.no_grad():
        for imgs, _ in val_loader:
            imgs = imgs.to(device)
            out = model(imgs)
            init_mse += F.mse_loss(out['recon'], imgs).item() * len(imgs)
            init_n += len(imgs)
            if init_n >= 2560:
                break
    init_mse /= init_n
    print(f"\n Initial MSE (random decoder): {init_mse:.4f}")
    print("=" * 70)

    # ---- epoch loop ----
    for epoch in range(1, CFG['epochs'] + 1):

        model.train()
        total_loss, total_recon, n = 0, 0, 0
        last_cv = target_cv  # until first measurement, assume on-target
        t0 = time.time()

        pbar = tqdm(train_loader,
                    desc=f"Ep {epoch}/{CFG['epochs']}",
                    bar_format='{l_bar}{bar:20}{r_bar}')

        for batch_idx, (images, _) in enumerate(pbar):
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)

            # Sample the CV of one patch matrix every 50 batches; reuse
            # the stale value in between (it's only a loss-shaping knob).
            with torch.no_grad():
                if batch_idx % 50 == 0:
                    cur_cv = cv_of(out['svd']['M'][0, 0])
                    if cur_cv > 0:
                        last_cv = cur_cv
            delta = last_cv - target_cv
            prox = math.exp(-delta**2 / (2 * sigma**2))

            # Near target_cv: boost recon weight. Far from it: add a CV
            # penalty. NOTE(review): last_cv is a detached Python float,
            # so the cv_pen term is a constant w.r.t. parameters — it
            # contributes NO gradient; only recon_w shaping has effect.
            # Confirm whether a differentiable CV was intended.
            recon_w = 1.0 + boost * prox
            cv_pen = cv_weight * (1.0 - prox)
            loss = recon_w * recon_loss + cv_pen * (last_cv - target_cv)**2
            loss.backward()

            # Clip only the cross-attention gradients.
            torch.nn.utils.clip_grad_norm_(
                model.cross_attn.parameters(), max_norm=0.5)

            opt.step()

            total_loss += loss.item() * len(images)
            total_recon += recon_loss.item() * len(images)
            n += len(images)
            pbar.set_postfix_str(f"mse={recon_loss.item():.4f} cv={last_cv:.3f}")

        sched.step()
        epoch_time = time.time() - t0

        # ---- validation ----
        model.eval()
        val_mse, val_n = 0, 0
        with torch.no_grad():
            for imgs, _ in val_loader:
                imgs = imgs.to(device)
                out = model(imgs)
                val_mse += F.mse_loss(out['recon'], imgs).item() * len(imgs)
                val_n += len(imgs)
        val_mse /= val_n

        # ---- spectral readout on one fixed-size sample ----
        with torch.no_grad():
            sample = next(iter(val_loader))[0][:64].to(device)
            out = model(sample)
            S_mean = out['svd']['S_orig'].mean(dim=(0, 1))
            S_coord = out['svd']['S'].mean(dim=(0, 1))  # NOTE(review): unused
            erank = model.effective_rank(
                out['svd']['S'].reshape(-1, D)).mean().item()
            s_delta = model.s_delta(out['svd']['S_orig'], out['svd']['S'])

        # Conduit telemetry (conduit_readout moves the batch to device).
        cond_stats = conduit_readout(model, sample.cpu(), device)

        is_best = val_mse < best_mse
        if is_best:
            best_mse = val_mse

        # NOTE(review): the 'β BEST' marker glyph is mojibake from the
        # original source encoding (likely a checkmark) — reconstructed.
        print(f"\n ep{epoch:3d} | recon={total_recon/n:.4f} val={val_mse:.4f} "
              f"{'β BEST' if is_best else ''} | "
              f"er={erank:.2f} Sd={s_delta:.4f} cv={last_cv:.3f} | {epoch_time:.0f}s")
        print_conduit_readout(cond_stats, D)

        # Per-noise-type MSE breakdown early on and on a fixed cadence.
        if epoch % CFG['val_per_type_every'] == 0 or epoch <= 3:
            type_mse = eval_per_type(model, CFG['img_size'], device)
            type_str = " ".join(
                f"{NOISE_NAMES[t][:4]}={v:.3f}" for t, v in sorted(type_mse.items()))
            print(f" types: {type_str}")

        # ---- append to the persistent JSON log ----
        log['entries'].append({
            'epoch': epoch,
            'train_mse': total_recon / n,
            'val_mse': val_mse,
            'cv': last_cv,
            'erank': erank,
            's_delta': s_delta,
            'S_mean': S_mean.cpu().tolist(),
            'conduit': cond_stats,
            'epoch_time': epoch_time,
            'lr': opt.param_groups[0]['lr'],
        })

        # ---- checkpointing: best + periodic ----
        if is_best:
            save_checkpoint(model, opt, sched, epoch, val_mse, log,
                            os.path.join(LOCAL_DIR, 'best.pt'),
                            is_best=True)

        if epoch % CFG['save_every'] == 0:
            save_checkpoint(model, opt, sched, epoch, val_mse, log,
                            os.path.join(LOCAL_DIR, f'epoch_{epoch:04d}.pt'))

    print(f"\n{'=' * 70}")
    print(f"v2 CONDUIT TRAINING COMPLETE β {VERSION}")
    print(f" Best MSE: {best_mse:.6f}")
    print(f" Epochs: {CFG['epochs']}")
    print(f" Params: {n_params:,}")
    print(f"{'=' * 70}")

    return model
|
|
|
|
| if __name__ == "__main__": |
| torch.set_float32_matmul_precision('high') |
| train() |