# geolip-conduit-experiments / svae_cadence.py
# (Hugging Face Hub page header: uploaded by AbstractPhil,
#  commit e37d8c5 "Update svae_cadence.py", verified)
"""
SVAE v2 Conduit Trainer β€” Prototype
=====================================
Train PatchSVAEv2 from random init on noise.
The decoder MUST reconstruct from decomposed spectral + conduit bundles.
No M_hat shortcut. Every conduit element is load-bearing.
Readouts per epoch:
Standard: MSE, S profile, erank, s_delta, CV
Conduit: friction stats, settle distribution, char_coeff profile,
per-mode reconstruction contribution
Should converge rapidly or fail β€” we'll know within 10 epochs.
Usage:
python train_v2_conduit.py
"""
import os
import math
import time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
# Best-effort Hugging Face login when running inside Google Colab: pull the
# HF_TOKEN secret from Colab userdata and authenticate. Outside Colab (or if
# the secret is missing) the import/get raises and we deliberately continue
# unauthenticated — uploads later will fail gracefully in save_checkpoint.
try:
    from google.colab import userdata
    os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
    from huggingface_hub import login
    login(token=os.environ["HF_TOKEN"])
except Exception:
    pass
from geolip_svae.model import cv_of, extract_patches, stitch_patches
from geolip_svae.model_v2 import PatchSVAEv2
# ═══════════════════════════════════════════════════════════════
# CONFIG
# ═══════════════════════════════════════════════════════════════
HF_REPO = 'AbstractPhil/geolip-SVAE'           # Hub repo receiving checkpoints + logs
VERSION = 'version2_v2_conduit_proto_2'        # run tag; prefixes every uploaded path
LOCAL_DIR = f'/content/{VERSION}_checkpoints'  # Colab-local checkpoint directory
LOG_PATH = os.path.join(LOCAL_DIR, 'training_log.json')
CFG = dict(
    # Architecture (inherited from Fresnel v50)
    V=16, D=4, ps=4, hidden=384, depth=4, n_cross=2,
    stage_hidden=128, stage_V=64,
    # Training
    img_size=64,
    batch_size=256,
    lr=3e-4,
    epochs=50,
    ds_size=1280000,   # samples drawn per training epoch
    val_size=10000,    # validation dataset size
    # CV soft hand: recon loss is up-weighted when the measured coefficient
    # of variation sits near target_cv, and a penalty grows as it drifts.
    target_cv=0.2915,  # target CV — NOTE(review): provenance of this constant unclear; confirm
    cv_weight=0.3,     # scale of the CV-drift penalty term
    boost=0.5,         # max extra recon weight when CV is at target
    sigma=0.15,        # Gaussian width of the CV proximity kernel
    # Checkpointing
    save_every=5,          # numbered checkpoint every N epochs
    val_per_type_every=5,  # per-noise-type MSE readout every N epochs
)
# ═══════════════════════════════════════════════════════════════
# NOISE DATASET (16 types, same as Freckles)
# ═══════════════════════════════════════════════════════════════
# Index → human-readable name for the 16 noise families produced by _gen_noise.
# The dataset selects the family as idx % 16, so every family appears equally often.
NOISE_NAMES = {
    0: 'gaussian', 1: 'uniform', 2: 'uniform_sc', 3: 'poisson',
    4: 'pink', 5: 'brown', 6: 'salt_pepper', 7: 'sparse',
    8: 'block', 9: 'gradient', 10: 'checker', 11: 'mixed',
    12: 'structural', 13: 'cauchy', 14: 'exponential', 15: 'laplace',
}
def _pink(shape):
w = torch.randn(shape)
S = torch.fft.rfft2(w)
h, ww = shape[-2], shape[-1]
fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, ww // 2 + 1)
fx = torch.fft.rfftfreq(ww).unsqueeze(0).expand(h, -1)
return torch.fft.irfft2(S / torch.sqrt(fx**2 + fy**2).clamp(min=1e-8), s=(h, ww))
def _brown(shape):
w = torch.randn(shape)
S = torch.fft.rfft2(w)
h, ww = shape[-2], shape[-1]
fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, ww // 2 + 1)
fx = torch.fft.rfftfreq(ww).unsqueeze(0).expand(h, -1)
return torch.fft.irfft2(S / (fx**2 + fy**2).clamp(min=1e-8), s=(h, ww))
def _gen_noise(noise_type, s, rng):
if noise_type == 0: return torch.randn(3, s, s)
elif noise_type == 1: return torch.rand(3, s, s) * 2 - 1
elif noise_type == 2: return (torch.rand(3, s, s) - 0.5) * 4
elif noise_type == 3:
lam = rng.uniform(0.5, 20.0)
return torch.poisson(torch.full((3, s, s), lam)) / lam - 1.0
elif noise_type == 4:
img = _pink((3, s, s)); return img / (img.std() + 1e-8)
elif noise_type == 5:
img = _brown((3, s, s)); return img / (img.std() + 1e-8)
elif noise_type == 6:
return torch.where(torch.rand(3, s, s) > 0.5,
torch.ones(3, s, s) * 2, -torch.ones(3, s, s) * 2) + torch.randn(3, s, s) * 0.1
elif noise_type == 7:
return torch.randn(3, s, s) * (torch.rand(3, s, s) > 0.9).float() * 3
elif noise_type == 8:
b = rng.randint(2, max(3, s // 2))
sm = torch.randn(3, s // b + 1, s // b + 1)
return F.interpolate(sm.unsqueeze(0), size=s, mode='nearest').squeeze(0)
elif noise_type == 9:
gy = torch.linspace(-2, 2, s).unsqueeze(1).expand(s, s)
gx = torch.linspace(-2, 2, s).unsqueeze(0).expand(s, s)
a = rng.uniform(0, 2 * math.pi)
return (math.cos(a) * gx + math.sin(a) * gy).unsqueeze(0).expand(3, -1, -1) + torch.randn(3, s, s) * 0.5
elif noise_type == 10:
cs = rng.randint(2, max(3, s // 2))
cy = torch.arange(s) // cs; cx = torch.arange(s) // cs
return ((cy.unsqueeze(1) + cx.unsqueeze(0)) % 2).float().unsqueeze(0).expand(3, -1, -1) * 2 - 1 + torch.randn(3, s, s) * 0.3
elif noise_type == 11:
alpha = rng.uniform(0.2, 0.8)
return alpha * torch.randn(3, s, s) + (1 - alpha) * (torch.rand(3, s, s) * 2 - 1)
elif noise_type == 12:
img = torch.zeros(3, s, s); h2 = s // 2
img[:, :h2, :h2] = torch.randn(3, h2, h2)
img[:, :h2, h2:] = torch.rand(3, h2, h2) * 2 - 1
img[:, h2:, :h2] = _pink((3, h2, h2)) / 2
img[:, h2:, h2:] = torch.where(torch.rand(3, h2, h2) > 0.5,
torch.ones(3, h2, h2), -torch.ones(3, h2, h2))
return img
elif noise_type == 13:
return torch.tan(math.pi * (torch.rand(3, s, s) - 0.5)).clamp(-3, 3)
elif noise_type == 14:
return torch.empty(3, s, s).exponential_(1.0) - 1.0
elif noise_type == 15:
u = torch.rand(3, s, s) - 0.5
return -torch.sign(u) * torch.log1p(-2 * u.abs())
return torch.randn(3, s, s)
class OmegaNoiseDataset(torch.utils.data.Dataset):
    """Synthetic noise dataset: idx % 16 picks the family from NOISE_NAMES.

    Samples are generated on the fly, clamped to [-4, 4]. Both the numpy and
    torch RNGs are periodically reseeded from os.urandom so long runs do not
    cycle through identical samples.
    """

    def __init__(self, size=1280000, img_size=64):
        self.size = size
        self.img_size = img_size
        self._rng = np.random.RandomState(42)
        self._call_count = 0

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        self._call_count += 1
        # Every 1000 draws, reseed numpy (scalar params) and torch (tensors).
        if self._call_count % 1000 == 0:
            fresh_seed = int.from_bytes(os.urandom(4), 'big')
            self._rng = np.random.RandomState(fresh_seed)
            torch.manual_seed(int.from_bytes(os.urandom(4), 'big'))
        family = idx % 16
        sample = _gen_noise(family, self.img_size, self._rng).clamp(-4, 4)
        return sample.float(), family
def eval_per_type(model, img_size, device, n_per=32):
    """Compute reconstruction MSE separately for each of the 16 noise types.

    Uses a fixed-seed RandomState(99) so the evaluation batches are identical
    across epochs. Leaves the model in eval mode.
    Returns: dict {noise_type int -> mse float}.
    """
    rng = np.random.RandomState(99)
    model.eval()
    per_type = {}
    with torch.no_grad():
        for family in range(16):
            batch = [_gen_noise(family, img_size, rng).clamp(-4, 4)
                     for _ in range(n_per)]
            imgs = torch.stack(batch).to(device)
            recon = model(imgs)['recon']
            per_type[family] = F.mse_loss(recon, imgs).item()
    return per_type
# ═══════════════════════════════════════════════════════════════
# CONDUIT READOUTS
# ═══════════════════════════════════════════════════════════════
def conduit_readout(model, images, device):
    """Extract and summarize conduit telemetry from a batch.

    Runs a no-grad forward pass, then reads ``model.last_conduit_packet`` and
    reduces each field to JSON-serializable summary statistics (floats and
    per-mode lists) for logging.

    Args:
        model: PatchSVAEv2-style model that populates ``last_conduit_packet``
            during forward. # assumes packet fields flatten to B*N*D elements — TODO confirm
        images: image batch; moved onto ``device`` before the forward pass.
        device: torch device for inference.

    Returns:
        dict of scalar floats and per-mode float lists.
    """
    model.eval()
    with torch.no_grad():
        out = model(images.to(device))
        packet = model.last_conduit_packet
        S = out['svd']['S_orig']  # (B, N, D): per-patch singular values
        B, N, D = S.shape
        # Conduit fields arrive flat; reshape to align with S's (B, N, D) layout.
        friction = packet.friction.reshape(B, N, D)
        settle = packet.settle.reshape(B, N, D)
        char_coeffs = packet.char_coeffs.reshape(B, N, D)
        ext_order = packet.extraction_order.reshape(B, N, D)  # NOTE(review): unused below
        refine_res = packet.refinement_residual.reshape(B, N)
        # Log-friction for readable stats (raw friction can span orders of magnitude)
        log_fric = torch.log1p(friction)
        stats = {
            'S_mean': S.mean(dim=(0, 1)).cpu().tolist(),          # per-mode mean singular value
            'S_std': S.std(dim=(0, 1)).cpu().tolist(),
            'friction_mean': friction.mean().item(),
            'friction_max': friction.max().item(),
            'friction_std': friction.std().item(),
            'log_fric_mean': log_fric.mean(dim=(0, 1)).cpu().tolist(),
            'log_fric_std': log_fric.std(dim=(0, 1)).cpu().tolist(),
            'settle_mean': settle.mean(dim=(0, 1)).cpu().tolist(),
            'settle_frac_gt2': (settle > 2).float().mean().item(),  # fraction of slow-settling modes
            'char_coeffs_mean': char_coeffs.mean(dim=(0, 1)).cpu().tolist(),
            'refine_res_mean': refine_res.mean().item(),
            'refine_res_max': refine_res.max().item(),
        }
        # Per-mode friction spatial CV: how unevenly friction is distributed
        # across the N patches of each image, averaged over the batch.
        for d in range(D):
            per_img = friction[:, :, d].reshape(B, -1)
            cvs = per_img.std(dim=1) / (per_img.mean(dim=1) + 1e-8)
            stats[f'friction_spatial_cv_mode{d}'] = cvs.mean().item()
        return stats
def print_conduit_readout(stats, D=4):
    """Render one conduit telemetry dict (from conduit_readout) to stdout."""
    def joined(values, spec):
        # Comma-join a list of floats with a fixed format spec.
        return ', '.join(format(v, spec) for v in values)

    print(f" S: [{joined(stats['S_mean'], '.3f')}]")
    print(f" S_std: [{joined(stats['S_std'], '.4f')}]")
    print(f" log_fric: [{joined(stats['log_fric_mean'], '.3f')}] "
          f"Β± [{joined(stats['log_fric_std'], '.3f')}]")
    print(f" fric_raw: mean={stats['friction_mean']:.1f} max={stats['friction_max']:.0f}")
    print(f" settle: [{joined(stats['settle_mean'], '.2f')}] "
          f"(>{2}: {stats['settle_frac_gt2']:.1%})")
    print(f" char_c: [{joined(stats['char_coeffs_mean'], '.4f')}]")
    print(f" refine: mean={stats['refine_res_mean']:.2e} max={stats['refine_res_max']:.2e}")
    spatial_cvs = [stats.get(f'friction_spatial_cv_mode{d}', 0) for d in range(D)]
    print(f" fric_cv: [{joined(spatial_cvs, '.4f')}]")
# ═══════════════════════════════════════════════════════════════
# LOGGING
# ═══════════════════════════════════════════════════════════════
def load_log():
    """Return the persisted training log, or a fresh skeleton for this VERSION."""
    if not os.path.exists(LOG_PATH):
        return {'version': VERSION, 'entries': []}
    with open(LOG_PATH) as fh:
        return json.load(fh)
def save_log(log):
    """Persist the training log as indented JSON, creating LOCAL_DIR if needed."""
    log_dir = os.path.dirname(LOG_PATH)
    os.makedirs(log_dir, exist_ok=True)
    payload = json.dumps(log, indent=2)
    with open(LOG_PATH, 'w') as fh:
        fh.write(payload)
# ═══════════════════════════════════════════════════════════════
# SAVE & PUSH
# ═══════════════════════════════════════════════════════════════
def save_checkpoint(model, opt, sched, epoch, val_mse, log,
                    path, is_best=False):
    """Save a full training checkpoint locally, then best-effort push it to HF Hub.

    The checkpoint bundles the model/optimizer/scheduler state plus the
    architecture config needed to rebuild the model. The Hub push (numbered
    checkpoint, optional best.pt alias, and the JSON log) is wrapped in a
    try/except so training never dies on a network failure.

    Args:
        model / opt / sched: objects whose state_dicts are serialized.
        epoch: 1-based epoch number stored in the checkpoint.
        val_mse: validation MSE recorded alongside (used in commit messages).
        log: training log dict, saved via save_log before upload.
        path: local .pt destination; parent dirs are created.
        is_best: when True, also upload the file as <VERSION>/checkpoints/best.pt.
    """
    os.makedirs(os.path.dirname(path), exist_ok=True)
    ckpt = {
        'config': {
            'V': CFG['V'], 'D': CFG['D'], 'patch_size': CFG['ps'],
            'hidden': CFG['hidden'], 'depth': CFG['depth'],
            'n_cross_layers': CFG['n_cross'],
            # .get defaults only apply if the key is absent from CFG
            'stage_hidden': CFG.get('stage_hidden', 128),
            'stage_V': CFG.get('stage_V', 16),
            'img_size': CFG['img_size'],
            'model_type': 'v2',
        },
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'scheduler_state_dict': sched.state_dict(),
        'epoch': epoch,
        'val_mse': val_mse,
    }
    torch.save(ckpt, path)
    size_mb = os.path.getsize(path) / (1024 * 1024)
    print(f" πŸ’Ύ {path} ({size_mb:.1f}MB, ep{epoch}, MSE={val_mse:.6f})")
    # Best-effort Hub sync: numbered checkpoint, optional best alias, then log.
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        api.upload_file(
            path_or_fileobj=path,
            path_in_repo=f'{VERSION}/checkpoints/{os.path.basename(path)}',
            repo_id=HF_REPO, repo_type='model',
            commit_message=f'{VERSION} ep{epoch} mse={val_mse:.6f}')
        if is_best:
            api.upload_file(
                path_or_fileobj=path,
                path_in_repo=f'{VERSION}/checkpoints/best.pt',
                repo_id=HF_REPO, repo_type='model',
                commit_message=f'{VERSION} BEST ep{epoch} mse={val_mse:.6f}')
        save_log(log)
        api.upload_file(
            path_or_fileobj=LOG_PATH,
            path_in_repo=f'{VERSION}/training_log.json',
            repo_id=HF_REPO, repo_type='model',
            commit_message=f'{VERSION} log ep{epoch}')
        print(f" ☁️ Pushed ep{epoch}")
    except Exception as e:
        # Push failures are non-fatal; the local checkpoint already exists.
        print(f" ⚠️ Push failed: {e}")
# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════
def train():
    """Train PatchSVAEv2 from random init on the 16-type noise dataset.

    Per epoch: train with MSE plus a CV "soft hand" weighting, run validation,
    snapshot spectral geometry (erank, s_delta) and conduit telemetry, append
    a log entry, and checkpoint (best.pt on improvement, numbered every
    CFG['save_every'] epochs).

    Returns:
        The trained model.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs(LOCAL_DIR, exist_ok=True)
    print("\n" + "=" * 70)
    print(f"SVAE v2 CONDUIT TRAINER β€” {VERSION}")
    print("=" * 70)
    # ── Fresh v2 model from random init ──
    D = CFG['D']
    model = PatchSVAEv2(
        V=CFG['V'], D=D, ps=CFG['ps'],
        hidden=CFG['hidden'], depth=CFG['depth'],
        n_cross=CFG['n_cross'],
        stage_hidden=CFG.get('stage_hidden', 128),
        stage_V=CFG.get('stage_V', 16),
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"\n Fresh PatchSVAEv2 from random init")
    print(f" Total params: {n_params:,}")
    # ── Data ──
    print(f"\n Dataset: 16 noise types, {CFG['ds_size']:,} samples/epoch")
    print(f" Image size: {CFG['img_size']}Γ—{CFG['img_size']}")
    print(f" Batch size: {CFG['batch_size']}")
    train_ds = OmegaNoiseDataset(size=CFG['ds_size'], img_size=CFG['img_size'])
    val_ds = OmegaNoiseDataset(size=CFG['val_size'], img_size=CFG['img_size'])
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=CFG['batch_size'], shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=CFG['batch_size'], shuffle=False,
        num_workers=4, pin_memory=True)
    # ── Optimizer (all params β€” full model training) ──
    opt = torch.optim.Adam(model.parameters(), lr=CFG['lr'])
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=CFG['epochs'])
    # CV soft hand hyper-parameters (see CFG for semantics)
    target_cv = CFG['target_cv']
    cv_weight = CFG['cv_weight']
    boost = CFG['boost']
    sigma = CFG['sigma']
    # Log
    log = load_log()
    best_mse = float('inf')
    # ── Initial conduit readout (random-init baseline for the telemetry) ──
    print(f"\n Initial conduit profile:")
    sample_batch = next(iter(val_loader))[0][:64]
    init_stats = conduit_readout(model, sample_batch, device)
    print_conduit_readout(init_stats, D)
    # ── Initial MSE (will be terrible β€” decoder is random) ──
    model.eval()
    init_mse = 0
    init_n = 0
    with torch.no_grad():
        for imgs, _ in val_loader:
            imgs = imgs.to(device)
            out = model(imgs)
            init_mse += F.mse_loss(out['recon'], imgs).item() * len(imgs)
            init_n += len(imgs)
            # Cap the baseline estimate at ~10 batches to keep startup fast
            if init_n >= 2560:
                break
    init_mse /= init_n
    print(f"\n Initial MSE (random decoder): {init_mse:.4f}")
    print("=" * 70)
    # ═══════════════════════════════════════════════════════════
    # TRAINING LOOP
    # ═══════════════════════════════════════════════════════════
    for epoch in range(1, CFG['epochs'] + 1):
        model.train()
        total_loss, total_recon, n = 0, 0, 0
        last_cv = target_cv  # seed with target until first measurement
        t0 = time.time()
        pbar = tqdm(train_loader,
                    desc=f"Ep {epoch}/{CFG['epochs']}",
                    bar_format='{l_bar}{bar:20}{r_bar}')
        for batch_idx, (images, _) in enumerate(pbar):
            images = images.to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)
            # CV soft hand: re-measure CV every 50 batches on one patch matrix
            with torch.no_grad():
                if batch_idx % 50 == 0:
                    cur_cv = cv_of(out['svd']['M'][0, 0])
                    if cur_cv > 0:
                        last_cv = cur_cv
            delta = last_cv - target_cv
            prox = math.exp(-delta**2 / (2 * sigma**2))  # 1 at target, →0 far away
            recon_w = 1.0 + boost * prox
            # NOTE(review): last_cv is measured under no_grad, so the cv_pen
            # term is a constant w.r.t. autograd — only recon_w affects
            # gradients. Confirm this is intended ("soft hand", not a loss).
            cv_pen = cv_weight * (1.0 - prox)
            loss = recon_w * recon_loss + cv_pen * (last_cv - target_cv)**2
            loss.backward()
            # Clip only the cross-attention block — presumably the unstable
            # part of the model; verify against model_v2. Other params unclipped.
            torch.nn.utils.clip_grad_norm_(
                model.cross_attn.parameters(), max_norm=0.5)
            opt.step()
            total_loss += loss.item() * len(images)
            total_recon += recon_loss.item() * len(images)
            n += len(images)
            pbar.set_postfix_str(f"mse={recon_loss.item():.4f} cv={last_cv:.3f}")
        sched.step()  # cosine LR decay, stepped once per epoch
        epoch_time = time.time() - t0
        # ── Validation ──
        model.eval()
        val_mse, val_n = 0, 0
        with torch.no_grad():
            for imgs, _ in val_loader:
                imgs = imgs.to(device)
                out = model(imgs)
                val_mse += F.mse_loss(out['recon'], imgs).item() * len(imgs)
                val_n += len(imgs)
        val_mse /= val_n
        # ── Geometry snapshot ──
        with torch.no_grad():
            sample = next(iter(val_loader))[0][:64].to(device)
            out = model(sample)
            S_mean = out['svd']['S_orig'].mean(dim=(0, 1))
            S_coord = out['svd']['S'].mean(dim=(0, 1))  # NOTE(review): unused below
            erank = model.effective_rank(
                out['svd']['S'].reshape(-1, D)).mean().item()
            s_delta = model.s_delta(out['svd']['S_orig'], out['svd']['S'])
        # ── Conduit readout ──
        # NOTE(review): sample is moved back to CPU only to be re-moved to
        # device inside conduit_readout — a round trip, but harmless.
        cond_stats = conduit_readout(model, sample.cpu(), device)
        # ── Print ──
        is_best = val_mse < best_mse
        if is_best:
            best_mse = val_mse
        print(f"\n ep{epoch:3d} | recon={total_recon/n:.4f} val={val_mse:.4f} "
              f"{'β˜… BEST' if is_best else ''} | "
              f"er={erank:.2f} Sd={s_delta:.4f} cv={last_cv:.3f} | {epoch_time:.0f}s")
        print_conduit_readout(cond_stats, D)
        # ── Per-type eval (always on the first 3 epochs, then periodically) ──
        if epoch % CFG['val_per_type_every'] == 0 or epoch <= 3:
            type_mse = eval_per_type(model, CFG['img_size'], device)
            type_str = " ".join(
                f"{NOISE_NAMES[t][:4]}={v:.3f}" for t, v in sorted(type_mse.items()))
            print(f" types: {type_str}")
        # ── Log entry ──
        # NOTE(review): assumes cv_of / s_delta return plain Python floats;
        # if they return 0-dim tensors, json serialization in save_log fails.
        log['entries'].append({
            'epoch': epoch,
            'train_mse': total_recon / n,
            'val_mse': val_mse,
            'cv': last_cv,
            'erank': erank,
            's_delta': s_delta,
            'S_mean': S_mean.cpu().tolist(),
            'conduit': cond_stats,
            'epoch_time': epoch_time,
            'lr': opt.param_groups[0]['lr'],
        })
        # ── Checkpoint ──
        if is_best:
            save_checkpoint(model, opt, sched, epoch, val_mse, log,
                            os.path.join(LOCAL_DIR, 'best.pt'),
                            is_best=True)
        if epoch % CFG['save_every'] == 0:
            save_checkpoint(model, opt, sched, epoch, val_mse, log,
                            os.path.join(LOCAL_DIR, f'epoch_{epoch:04d}.pt'))
    # ═══════════════════════════════════════════════════════════
    # DONE
    # ═══════════════════════════════════════════════════════════
    print(f"\n{'=' * 70}")
    print(f"v2 CONDUIT TRAINING COMPLETE β€” {VERSION}")
    print(f" Best MSE: {best_mse:.6f}")
    print(f" Epochs: {CFG['epochs']}")
    print(f" Params: {n_params:,}")
    print(f"{'=' * 70}")
    return model
if __name__ == "__main__":
torch.set_float32_matmul_precision('high')
train()