"""
IRIS Architecture Validation Tests
===================================
Tests the wavelet roundtrip, VAE, GRFM, generator forward pass, training step,
generation, adaptive compute budget, memory profile, and effective depth.
"""
|
|
import sys
import time
import traceback

import torch

from iris_model import (
    IRIS, IRISConfig, create_iris_small, create_iris_tiny, create_iris_base,
    count_parameters, estimate_memory_mb,
    HaarDWT2D, HaarIDWT2D, WaveletVAE, IRISGenerator, GRFM
)
|
|
|
|
def test_wavelet_transform():
    """Check that a Haar DWT followed by the inverse DWT reproduces its input.

    A random ``(2, 3, 64, 64)`` tensor is passed through ``HaarDWT2D`` and then
    ``HaarIDWT2D``; the reconstruction is compared with the original tensor.

    Returns:
        True when the check passes.

    Raises:
        AssertionError: If the reconstruction has a different shape than the
            input, or the maximum absolute error is not below 1e-5.
    """
    print("=" * 60)
    print("Test 1: Wavelet Transform Roundtrip")
    print("=" * 60)
    dwt = HaarDWT2D()
    idwt = HaarIDWT2D()

    x = torch.randn(2, 3, 64, 64)
    y = dwt(x)
    x_recon = idwt(y)

    print(f" Input shape: {list(x.shape)}")
    print(f" DWT shape: {list(y.shape)}")
    print(f" Recon shape: {list(x_recon.shape)}")

    # Check the shape explicitly: a broadcastable mismatch would otherwise be
    # hidden by the element-wise subtraction below.
    assert x_recon.shape == x.shape, (
        f"Roundtrip changed shape: {list(x.shape)} -> {list(x_recon.shape)}"
    )

    error = (x - x_recon).abs().max().item()
    print(f" Max error: {error:.2e}")
    assert error < 1e-5, f"DWT roundtrip error too high: {error}"
    print(" ✅ PASSED (lossless roundtrip)")
    return True
|
|
|
|
def test_vae():
    """Encode and decode a random image batch with ``WaveletVAE`` and report sizes.

    Prints the input, latent and reconstruction shapes, the element-count
    ratio between input and latent, the VAE parameter count and its size
    at 2 bytes per parameter (fp16).

    Returns:
        True when encode and decode run without raising.
    """
    print("\n" + "=" * 60)
    print("Test 2: Wavelet VAE")
    print("=" * 60)
    config = IRISConfig(
        latent_channels=16,
        latent_spatial=32,
        vae_channels=[32, 64, 128, 256],
    )
    vae = WaveletVAE(config)

    x = torch.randn(2, 3, 256, 256)

    # encode() returns the sampled latent together with its mean and log-variance;
    # only the latent is needed for decoding here.
    z, mean, logvar = vae.encode(x)
    x_recon = vae.decode(z)

    print(f" Input shape: {list(x.shape)}")
    print(f" Latent shape: {list(z.shape)}")
    print(f" Recon shape: {list(x_recon.shape)}")
    print(f" Compression: {x.numel() / z.numel():.1f}×")

    vae_params = sum(p.numel() for p in vae.parameters())
    print(f" VAE params: {vae_params:,}")
    # 2 bytes per parameter -> MB, as labelled "(fp16)" in the output.
    print(f" VAE memory: {vae_params * 2 / 1024 / 1024:.1f} MB (fp16)")
    print(" ✅ PASSED")
    return True
|
|
|
|
def test_grfm():
    """Run the GRFM module on its own: one forward pass and one backward pass.

    A random ``(B, H*W, D)`` tensor is fed to ``GRFM`` together with ``H`` and
    ``W``; the forward time, output shape and parameter count are printed.
    ``out.sum()`` is then back-propagated and every trainable parameter is
    checked for a gradient.

    Returns:
        True when the check passes.

    Raises:
        AssertionError: If any trainable GRFM parameter has no gradient after
            the backward pass.
    """
    print("\n" + "=" * 60)
    print("Test 3: GRFM (Gated Recurrent Fourier Mixer)")
    print("=" * 60)
    config = IRISConfig(
        hidden_dim=256,
        num_heads=4,
        fourier_num_blocks=4,
        recurrence_dim=128,
        manhattan_window=8,
    )
    grfm = GRFM(config)

    B, H, W, D = 2, 8, 8, 256
    x = torch.randn(B, H * W, D)

    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    t0 = time.perf_counter()
    out = grfm(x, H, W)
    t1 = time.perf_counter()

    print(f" Input: [B={B}, N={H*W}, D={D}]")
    print(f" Output: {list(out.shape)}")
    print(f" Time: {(t1-t0)*1000:.1f} ms")

    grfm_params = sum(p.numel() for p in grfm.parameters())
    print(f" Params: {grfm_params:,}")

    loss = out.sum()
    loss.backward()

    # Collect names so that a failure says which parameters were left out.
    missing = [
        name
        for name, p in grfm.named_parameters()
        if p.requires_grad and p.grad is None
    ]
    grad_ok = not missing
    print(f" Gradients: {'✅ All flowing' if grad_ok else '❌ Some missing'}")
    # Previously the result was only printed, so the test reported PASSED even
    # when gradients were missing.
    assert grad_ok, f"GRFM parameters without gradients: {missing}"
    print(" ✅ PASSED")
    return True
|
|
|
|
def test_generator_forward():
    """Run ``IRISGenerator`` forward for several iteration counts.

    The generator is called with a random latent ``z_t``, random timesteps
    ``t`` and random text tokens for ``num_iterations`` in ``[2, 4, 8]``.
    For each call the output shape and elapsed time are printed.

    Returns:
        True when the check passes.

    Raises:
        AssertionError: If, for any iteration count, the output shape differs
            from the shape of ``z_t``.
    """
    print("\n" + "=" * 60)
    print("Test 4: Generator Forward Pass")
    print("=" * 60)
    config = IRISConfig(
        latent_channels=8,
        latent_spatial=8,
        hidden_dim=256,
        num_heads=4,
        head_dim=64,
        num_prelude_blocks=1,
        num_core_layers=2,
        num_coda_blocks=1,
        default_iterations=4,
        fourier_num_blocks=4,
        recurrence_dim=128,
        manhattan_window=8,
        text_dim=768,
        patch_size=2,
    )
    gen = IRISGenerator(config)

    B = 2
    z_t = torch.randn(
        B, config.latent_channels, config.latent_spatial, config.latent_spatial
    )
    t = torch.rand(B)
    text_tokens = torch.randn(B, 77, config.text_dim)

    for r in [2, 4, 8]:
        t0 = time.perf_counter()
        v_pred = gen(z_t, t, text_tokens, num_iterations=r)
        t1 = time.perf_counter()
        print(f" r={r:2d}: output={list(v_pred.shape)}, time={1000*(t1-t0):.0f}ms")
        # Checked per iteration count so a mismatch at r=2 or r=4 is not masked
        # by a correct shape on the final pass.
        assert v_pred.shape == z_t.shape, "Output shape mismatch"

    gen_params = sum(p.numel() for p in gen.parameters())
    print(f" Generator params: {gen_params:,}")
    print(" Note: Core block shared across all iterations!")
    print(" ✅ PASSED")
    return True
|
|
|
|
def test_training_step():
    """Run one ``IRIS.train_step`` plus a backward pass on random data.

    Prints the loss components returned by ``train_step`` (keys ``loss``,
    ``velocity_loss``, ``kl_loss``, ``mean_t``), the forward and backward
    times, and how many parameters received a gradient.

    Returns:
        True when the check passes.

    Raises:
        AssertionError: If the loss is not finite, or no parameter has a
            gradient after ``backward()``.
    """
    print("\n" + "=" * 60)
    print("Test 5: Training Step")
    print("=" * 60)
    config = IRISConfig(
        latent_channels=8,
        latent_spatial=8,
        hidden_dim=256,
        num_heads=4,
        head_dim=64,
        num_prelude_blocks=1,
        num_core_layers=2,
        num_coda_blocks=1,
        default_iterations=4,
        fourier_num_blocks=4,
        recurrence_dim=128,
        manhattan_window=8,
        text_dim=768,
        patch_size=2,
        vae_channels=[16, 32, 64, 128],
    )
    model = IRIS(config)

    B = 2
    images = torch.randn(B, 3, 128, 128)
    text_tokens = torch.randn(B, 77, config.text_dim)

    t0 = time.perf_counter()
    result = model.train_step(images, text_tokens, num_iterations=4)
    t1 = time.perf_counter()

    loss = result['loss']
    print(f" Loss: {loss.item():.4f}")
    print(f" Velocity loss: {result['velocity_loss']:.4f}")
    print(f" KL loss: {result['kl_loss']:.4f}")
    print(f" Mean t: {result['mean_t']:.3f}")
    print(f" Time: {(t1-t0)*1000:.0f} ms")
    # A NaN/inf loss would still back-propagate without raising, so check it here.
    assert torch.isfinite(loss).all(), f"Training loss is not finite: {loss.item()}"

    t0 = time.perf_counter()
    loss.backward()
    t1 = time.perf_counter()
    print(f" Backward time: {(t1-t0)*1000:.0f} ms")

    n_grads = sum(1 for p in model.parameters() if p.grad is not None)
    n_params = sum(1 for p in model.parameters())
    print(f" Gradients: {n_grads}/{n_params} params have gradients")
    # Only require that backward reached the model at all; the exact ratio is
    # reported above for information.
    assert n_grads > 0, "backward() produced no parameter gradients"
    print(" ✅ PASSED")
    return True
|
|
|
|
def test_generation():
    """Run ``IRIS.generate`` for several (steps, iterations) settings.

    The model is put in eval mode and called under ``torch.no_grad()`` with
    random text tokens, ``cfg_scale=1.0`` and ``seed=42``. For each setting the
    output shape, value range and elapsed time are printed.

    Returns:
        True when the check passes.

    Raises:
        AssertionError: If any setting yields an output whose shape is not
            ``(B, 3, 128, 128)``.
    """
    print("\n" + "=" * 60)
    print("Test 6: Image Generation Pipeline")
    print("=" * 60)
    config = IRISConfig(
        latent_channels=8,
        latent_spatial=8,
        hidden_dim=256,
        num_heads=4,
        head_dim=64,
        num_prelude_blocks=1,
        num_core_layers=2,
        num_coda_blocks=1,
        default_iterations=4,
        fourier_num_blocks=4,
        recurrence_dim=128,
        manhattan_window=8,
        text_dim=768,
        patch_size=2,
        vae_channels=[16, 32, 64, 128],
    )
    model = IRIS(config)
    model.eval()

    B = 2
    text_tokens = torch.randn(B, 77, config.text_dim)

    for steps, iters in [(1, 4), (4, 4), (4, 8)]:
        t0 = time.perf_counter()
        with torch.no_grad():
            images = model.generate(
                text_tokens,
                num_steps=steps,
                num_iterations=iters,
                cfg_scale=1.0,
                seed=42,
            )
        t1 = time.perf_counter()
        print(f" steps={steps}, iters={iters}: shape={list(images.shape)}, "
              f"range=[{images.min():.2f}, {images.max():.2f}], time={1000*(t1-t0):.0f}ms")
        # Verified for every setting, not just the last one.
        assert images.shape == (B, 3, 128, 128), f"Unexpected output shape: {images.shape}"

    print(" ✅ PASSED")
    return True
|
|
|
|
def test_adaptive_compute():
    """Check that changing the core iteration count changes the generated image.

    Images are generated with the same text tokens, ``num_steps=2`` and
    ``seed=42`` for ``num_iterations`` in ``[2, 4, 8, 12]``; mean absolute
    differences between selected pairs are printed.

    Returns:
        True when the check passes.

    Raises:
        AssertionError: If the outputs for r=4 and r=8, or for r=8 and r=12,
            are identical.
    """
    print("\n" + "=" * 60)
    print("Test 7: Adaptive Compute Budget")
    print("=" * 60)
    config = IRISConfig(
        latent_channels=8,
        latent_spatial=8,
        hidden_dim=256,
        num_heads=4,
        head_dim=64,
        num_prelude_blocks=1,
        num_core_layers=2,
        num_coda_blocks=1,
        default_iterations=4,
        fourier_num_blocks=4,
        recurrence_dim=128,
        manhattan_window=8,
        text_dim=768,
        patch_size=2,
        vae_channels=[16, 32, 64, 128],
    )
    model = IRIS(config)
    model.eval()

    text_tokens = torch.randn(1, 77, config.text_dim)

    # Overwrite the output projection weight and every core parameter whose name
    # contains 'adaln' with random normal values.
    # NOTE(review): presumably these are zero-initialised by default, which would
    # make all iteration counts produce the same output — confirm in iris_model.
    with torch.no_grad():
        model.generator.output_proj.weight.normal_(0, 0.02)
        for name, param in model.generator.core.named_parameters():
            if 'adaln' in name:
                param.normal_(0, 0.1)

    # Same seed for every run, so any difference comes from num_iterations alone.
    results = {}
    with torch.no_grad():
        for r in [2, 4, 8, 12]:
            results[r] = model.generate(text_tokens, num_steps=2, num_iterations=r,
                                        cfg_scale=1.0, seed=42)

    def mean_abs_diff(a, b):
        """Mean absolute difference between the images generated at r=a and r=b."""
        return (results[a] - results[b]).abs().mean().item()

    diff_4_8 = mean_abs_diff(4, 8)
    diff_8_12 = mean_abs_diff(8, 12)
    diff_2_12 = mean_abs_diff(2, 12)

    print(f" Diff(r=4, r=8): {diff_4_8:.4f}")
    print(f" Diff(r=8, r=12): {diff_8_12:.4f}")
    print(f" Diff(r=2, r=12): {diff_2_12:.4f}")
    # Informational only: a warning sign is printed instead of failing the test.
    print(f" More iterations → more refinement: {'✅' if diff_2_12 > diff_8_12 else '⚠️'}")

    assert diff_4_8 > 0, "r=4 and r=8 should differ"
    assert diff_8_12 > 0, "r=8 and r=12 should differ"
    print(" ✅ PASSED")
    return True
|
|
|
|
def test_memory_profile():
    """Print a parameter and memory breakdown for the Tiny and Small models.

    For each model built by ``create_iris_tiny`` / ``create_iris_small`` the
    VAE, generator, prelude, core and coda parameter counts are printed,
    followed by a total inference estimate: fp16 model weights plus fixed
    allowances for the text encoder, activations and OS overhead.

    This test only reports; it makes no assertions.

    Returns:
        True once both profiles have been printed.
    """
    print("\n" + "=" * 60)
    print("Test 8: Memory Profile for Mobile Deployment")
    print("=" * 60)

    def num_params(module):
        """Total number of elements over all parameters of *module*."""
        return sum(p.numel() for p in module.parameters())

    def fp16_mb(n_params):
        """Size in MB of *n_params* parameters stored at 2 bytes each."""
        return n_params * 2 / 1024 / 1024

    for name, create_fn in [("IRIS-Tiny", create_iris_tiny),
                            ("IRIS-Small", create_iris_small)]:
        model = create_fn()

        vae_params = num_params(model.vae)
        gen_params = num_params(model.generator)

        # Generator sub-modules, reported as a breakdown of gen_params.
        core_params = num_params(model.generator.core)
        prelude_params = num_params(model.generator.prelude)
        coda_params = num_params(model.generator.coda)

        vae_mb = fp16_mb(vae_params)
        gen_mb = fp16_mb(gen_params)

        model_mb = fp16_mb(vae_params + gen_params)
        # Fixed allowances in MB; the values are hard-coded estimates.
        # NOTE(review): their provenance is not visible in this file — confirm.
        text_enc_mb = 156
        activation_mb = 50
        overhead_mb = 300
        total_mb = model_mb + text_enc_mb + activation_mb + overhead_mb

        print(f"\n {name}:")
        print(f" VAE: {vae_params:>10,} params = {vae_mb:>6.1f} MB")
        print(f" Generator: {gen_params:>10,} params = {gen_mb:>6.1f} MB")
        print(f" Prelude: {prelude_params:>10,}")
        print(f" Core: {core_params:>10,} (shared, iterated r times)")
        print(f" Coda: {coda_params:>10,}")
        print(" " + "─" * 32)
        print(f" Model total: {model_mb:>6.1f} MB (fp16)")
        print(f" + CLIP-L/14: {text_enc_mb:>6.1f} MB")
        print(f" + Activations: {activation_mb:>6.1f} MB")
        print(f" + OS overhead: {overhead_mb:>6.1f} MB")
        print(" " + "─" * 31)
        print(f" TOTAL INFERENCE: {total_mb:>6.1f} MB")
        # Thresholds use 1 GB = 1000 MB, which is slightly stricter than 1024.
        print(f" Fits in 3GB: {'✅ YES' if total_mb < 3000 else '❌ NO'}")
        print(f" Fits in 4GB: {'✅ YES' if total_mb < 4000 else '❌ NO'}")

    print("\n ✅ PASSED")
    return True
|
|
|
|
def test_effective_depth():
    """Print effective layer and parameter counts for several iteration counts.

    Uses the model from ``create_iris_small``. For ``r`` iterations of the
    shared core:

    * effective layers = prelude blocks + r * core layers + coda blocks
    * effective params = unique params + (r - 1) * core params

    This test only reports; it makes no assertions.

    Returns:
        True once the analysis has been printed.
    """
    print("\n" + "=" * 60)
    print("Test 9: Effective Depth Analysis")
    print("=" * 60)

    model = create_iris_small()
    config = model.config

    core_params = sum(p.numel() for p in model.generator.core.parameters())
    total_unique = sum(p.numel() for p in model.parameters())

    layers_per_iteration = config.num_core_layers

    def effective_param_count(r):
        """Unique parameters plus the core counted once more per extra iteration."""
        return total_unique + (r - 1) * core_params

    print(f" Architecture: Prelude({config.num_prelude_blocks}) → "
          f"Core({config.num_core_layers} layers × r iters) → "
          f"Coda({config.num_coda_blocks})")
    print(f" Unique params: {total_unique:,}")
    print(f" Core params: {core_params:,} (shared)")
    print()

    for r in [4, 8, 12, 16]:
        effective_layers = (
            config.num_prelude_blocks
            + r * layers_per_iteration
            + config.num_coda_blocks
        )
        effective_params = effective_param_count(r)

        print(f" r={r:2d}: {effective_layers} effective layers, "
              f"~{effective_params/1e6:.0f}M effective params, "
              f"from {total_unique/1e6:.0f}M unique")

    # Same formula as in the loop, evaluated at r=16, relative to the unique count.
    capacity_ratio = effective_param_count(16) / total_unique
    print(f"\n → 16× iteration gives {capacity_ratio:.1f}× "
          f"effective capacity from same model!")
    print(" ✅ PASSED")
    return True
|
|
|
|
def main():
    """Run every validation test in order and print a pass/fail summary.

    A test counts as passed only when it returns a truthy value. A test that
    raises, or that returns a falsy value, counts as failed, so that
    ``passed + failed`` always equals the number of tests.

    Returns:
        The number of failed tests.
    """
    print("🔬 IRIS Architecture Validation Suite")
    print("=" * 60)

    tests = [
        test_wavelet_transform,
        test_vae,
        test_grfm,
        test_generator_forward,
        test_training_step,
        test_generation,
        test_adaptive_compute,
        test_memory_profile,
        test_effective_depth,
    ]

    passed = 0
    failed = 0
    for test in tests:
        try:
            ok = test()
        except Exception as e:
            # Top-level boundary: report the failure with its traceback and
            # continue with the remaining tests.
            print(f" ❌ FAILED: {e}")
            traceback.print_exc()
            failed += 1
        else:
            if ok:
                passed += 1
            else:
                # Previously a falsy return was counted as neither passed nor
                # failed, so the run could exit with status 0.
                print(f" ❌ FAILED: {test.__name__} returned {ok!r}")
                failed += 1

    print(f"\n{'=' * 60}")
    print(f"Results: {passed} passed, {failed} failed out of {len(tests)} tests")
    print(f"{'=' * 60}")
    return failed


if __name__ == "__main__":
    if main() > 0:
        sys.exit(1)
|
|