# IRIS-architecture / test_iris.py
# asdf98's picture
# Add test_iris.py
# 89579fd verified
"""
IRIS Architecture Validation Tests
===================================
Tests forward pass, training step, generation, and memory profile.
"""
import torch
import time
import sys
from iris_model import (
IRIS, IRISConfig, create_iris_small, create_iris_tiny, create_iris_base,
count_parameters, estimate_memory_mb,
HaarDWT2D, HaarIDWT2D, WaveletVAE, IRISGenerator, GRFM
)
def test_wavelet_transform():
    """Round-trip a random batch through the Haar DWT and its inverse.

    A random tensor is pushed through ``HaarDWT2D`` and then ``HaarIDWT2D``;
    the largest absolute element-wise difference between the input and the
    reconstruction must stay below 1e-5.

    Returns:
        True when the reconstruction error is within tolerance.

    Raises:
        AssertionError: If the maximum absolute error is 1e-5 or larger.
    """
    rule = "=" * 60
    print(rule)
    print("Test 1: Wavelet Transform Roundtrip")
    print(rule)

    forward_transform = HaarDWT2D()
    inverse_transform = HaarIDWT2D()

    image = torch.randn(2, 3, 64, 64)
    coeffs = forward_transform(image)
    restored = inverse_transform(coeffs)

    # Worst-case element-wise deviation between input and reconstruction.
    residual = image - restored
    max_abs_err = torch.max(torch.abs(residual)).item()

    print(f" Input shape: {list(image.shape)}")
    print(f" DWT shape: {list(coeffs.shape)}")
    print(f" Recon shape: {list(restored.shape)}")
    print(f" Max error: {max_abs_err:.2e}")

    assert max_abs_err < 1e-5, f"DWT roundtrip error too high: {max_abs_err}"
    print(" βœ… PASSED (lossless roundtrip)")
    return True
def test_vae():
    """Encode and decode a random batch with ``WaveletVAE`` and check shapes.

    Prints the input, latent and reconstruction shapes, the element-count
    compression ratio and the VAE parameter count / fp16 size, then verifies
    that decoding the latent yields a tensor shaped like the input.

    Returns:
        True when the reconstruction shape check passes.

    Raises:
        AssertionError: If the decoded tensor's shape differs from the input.
    """
    print("\n" + "=" * 60)
    print("Test 2: Wavelet VAE")
    print("=" * 60)
    config = IRISConfig(
        latent_channels=16,
        latent_spatial=32,
        vae_channels=[32, 64, 128, 256],
    )
    vae = WaveletVAE(config)
    # Pipeline for a 256x256 input, per the original author's notes:
    #   DWT     -> 12 x 128 x 128
    #   conv_in -> 32 x 128 x 128
    #   Down1 -> 64x64, Down2 -> 32x32, Down3 -> 16x16
    # NOTE(review): the config above sets latent_spatial=32 while these notes
    # end at 16x16 - confirm which one is intended.
    x = torch.randn(2, 3, 256, 256)
    # encode() returns three values; only the first (the latent) is used here.
    z, _mean, _logvar = vae.encode(x)
    x_recon = vae.decode(z)
    print(f" Input shape: {list(x.shape)}")
    print(f" Latent shape: {list(z.shape)}")
    print(f" Recon shape: {list(x_recon.shape)}")
    print(f" Compression: {x.numel() / z.numel():.1f}Γ—")
    vae_params = sum(p.numel() for p in vae.parameters())
    print(f" VAE params: {vae_params:,}")
    # 2 bytes per parameter in fp16.
    print(f" VAE memory: {vae_params * 2 / 1024 / 1024:.1f} MB (fp16)")
    # Without this check the test could only fail by raising inside the model.
    assert x_recon.shape == x.shape, (
        f"VAE reconstruction shape {list(x_recon.shape)} "
        f"does not match input shape {list(x.shape)}"
    )
    print(" βœ… PASSED")
    return True
def test_grfm():
    """Run the GRFM module on its own and verify gradients reach every parameter.

    Performs a timed forward pass on a random ``[B, H*W, D]`` tensor, prints
    the output shape and parameter count, back-propagates ``out.sum()`` and
    then requires that every trainable parameter received a gradient.

    Returns:
        True when all trainable parameters have a gradient.

    Raises:
        AssertionError: If any trainable parameter has ``grad is None`` after
            the backward pass; the message lists the parameter names.
    """
    print("\n" + "=" * 60)
    print("Test 3: GRFM (Gated Recurrent Fourier Mixer)")
    print("=" * 60)
    config = IRISConfig(
        hidden_dim=256,
        num_heads=4,
        fourier_num_blocks=4,
        recurrence_dim=128,
        manhattan_window=8,
    )
    grfm = GRFM(config)
    B, H, W, D = 2, 8, 8, 256
    x = torch.randn(B, H * W, D)
    t0 = time.time()
    out = grfm(x, H, W)
    t1 = time.time()
    print(f" Input: [B={B}, N={H*W}, D={D}]")
    print(f" Output: {list(out.shape)}")
    print(f" Time: {(t1-t0)*1000:.1f} ms")
    grfm_params = sum(p.numel() for p in grfm.parameters())
    print(f" Params: {grfm_params:,}")
    # Test gradient flow
    loss = out.sum()
    loss.backward()
    # Collect names so a failure points at the exact parameters left out.
    missing_grads = [
        name
        for name, p in grfm.named_parameters()
        if p.requires_grad and p.grad is None
    ]
    grad_ok = not missing_grads
    print(f" Gradients: {'βœ… All flowing' if grad_ok else '❌ Some missing'}")
    # Previously the result was only printed and the test passed regardless.
    assert grad_ok, f"Parameters without gradients: {missing_grads}"
    print(" βœ… PASSED")
    return True
def test_generator_forward():
    """Run the generator forward pass at several iteration counts.

    Builds a small ``IRISGenerator`` and calls it with ``num_iterations`` of
    2, 4 and 8, checking each time that the prediction has the same shape as
    the noisy latent input.

    Returns:
        True when every forward pass produces the expected shape.

    Raises:
        AssertionError: If any output shape differs from the input latent.
    """
    print("\n" + "=" * 60)
    print("Test 4: Generator Forward Pass")
    print("=" * 60)

    settings = {
        "latent_channels": 8,
        "latent_spatial": 8,
        "hidden_dim": 256,
        "num_heads": 4,
        "head_dim": 64,
        "num_prelude_blocks": 1,
        "num_core_layers": 2,
        "num_coda_blocks": 1,
        "default_iterations": 4,
        "fourier_num_blocks": 4,
        "recurrence_dim": 128,
        "manhattan_window": 8,
        "text_dim": 768,
        "patch_size": 2,
    }
    config = IRISConfig(**settings)
    generator = IRISGenerator(config)

    batch = 2
    side = config.latent_spatial
    noisy_latent = torch.randn(batch, config.latent_channels, side, side)
    timesteps = torch.rand(batch)
    text_tokens = torch.randn(batch, 77, config.text_dim)

    # Sweep the recurrence depth; the output shape must not depend on it.
    for r in (2, 4, 8):
        started = time.time()
        v_pred = generator(noisy_latent, timesteps, text_tokens, num_iterations=r)
        elapsed_ms = 1000 * (time.time() - started)
        print(f" r={r:2d}: output={list(v_pred.shape)}, time={elapsed_ms:.0f}ms")
        assert v_pred.shape == noisy_latent.shape, "Output shape mismatch"

    gen_params = sum(weight.numel() for weight in generator.parameters())
    print(f" Generator params: {gen_params:,}")
    print(f" Note: Core block shared across all iterations!")
    print(" βœ… PASSED")
    return True
def test_training_step():
    """Run one full training step (forward + backward) and sanity-check it.

    Calls ``model.train_step`` on random images and text tokens, prints the
    reported losses and timings, back-propagates the total loss and reports
    how many parameters received gradients.

    Returns:
        True when the loss is finite and at least one parameter has a gradient.

    Raises:
        AssertionError: If the loss is NaN/inf, or if no parameter received a
            gradient from the backward pass.
    """
    print("\n" + "=" * 60)
    print("Test 5: Training Step")
    print("=" * 60)
    config = IRISConfig(
        latent_channels=8,
        latent_spatial=8,  # VAE with DWT + 3 down blocks: 128->DWT->64->32->16->8
        hidden_dim=256,
        num_heads=4,
        head_dim=64,
        num_prelude_blocks=1,
        num_core_layers=2,
        num_coda_blocks=1,
        default_iterations=4,
        fourier_num_blocks=4,
        recurrence_dim=128,
        manhattan_window=8,
        text_dim=768,
        patch_size=2,
        vae_channels=[16, 32, 64, 128],
    )
    model = IRIS(config)
    # Simulate training
    B = 2
    # Input image size: 128x128
    # DWT: 128 -> 64 (x12 channels), Down x3: 64 -> 32 -> 16 -> 8
    # So the latent is 8x8 with latent_channels channels.
    images = torch.randn(B, 3, 128, 128)
    text_tokens = torch.randn(B, 77, config.text_dim)
    # Forward
    t0 = time.time()
    result = model.train_step(images, text_tokens, num_iterations=4)
    t1 = time.time()
    print(f" Loss: {result['loss'].item():.4f}")
    print(f" Velocity loss: {result['velocity_loss']:.4f}")
    print(f" KL loss: {result['kl_loss']:.4f}")
    print(f" Mean t: {result['mean_t']:.3f}")
    print(f" Time: {(t1-t0)*1000:.0f} ms")
    # A NaN/inf loss would otherwise print and still be reported as a pass.
    assert torch.isfinite(result['loss']).all(), (
        f"Training loss is not finite: {result['loss'].item()}"
    )
    # Backward
    t0 = time.time()
    result['loss'].backward()
    t1 = time.time()
    print(f" Backward time: {(t1-t0)*1000:.0f} ms")
    # Check gradients
    n_grads = sum(1 for p in model.parameters() if p.grad is not None)
    n_params = sum(1 for p in model.parameters())
    print(f" Gradients: {n_grads}/{n_params} params have gradients")
    assert n_grads > 0, "Backward pass produced no parameter gradients"
    print(" βœ… PASSED")
    return True
def test_generation():
    """Exercise the end-to-end ``generate`` pipeline under several budgets.

    For each (sampling steps, core iterations) pair the model generates a
    batch of images from random text tokens with a fixed seed; the output
    must have shape ``(B, 3, 128, 128)``.

    Returns:
        True when every generated batch has the expected shape.

    Raises:
        AssertionError: If a generated batch has an unexpected shape.
    """
    print("\n" + "=" * 60)
    print("Test 6: Image Generation Pipeline")
    print("=" * 60)

    config = IRISConfig(
        latent_channels=8,
        latent_spatial=8,
        hidden_dim=256,
        num_heads=4,
        head_dim=64,
        num_prelude_blocks=1,
        num_core_layers=2,
        num_coda_blocks=1,
        default_iterations=4,
        fourier_num_blocks=4,
        recurrence_dim=128,
        manhattan_window=8,
        text_dim=768,
        patch_size=2,
        vae_channels=[16, 32, 64, 128],
    )
    model = IRIS(config)
    model.eval()

    B = 2
    expected_shape = (B, 3, 128, 128)
    text_tokens = torch.randn(B, 77, config.text_dim)

    # (sampling steps, core iterations) combinations to try.
    budgets = ((1, 4), (4, 4), (4, 8))
    for steps, iters in budgets:
        began = time.time()
        with torch.no_grad():
            # cfg_scale=1.0: no CFG for the speed test.
            images = model.generate(
                text_tokens,
                num_steps=steps,
                num_iterations=iters,
                cfg_scale=1.0,
                seed=42,
            )
        elapsed_ms = 1000 * (time.time() - began)
        lowest = images.min()
        highest = images.max()
        print(f" steps={steps}, iters={iters}: shape={list(images.shape)}, "
              f"range=[{lowest:.2f}, {highest:.2f}], time={elapsed_ms:.0f}ms")
        assert images.shape == expected_shape, f"Unexpected output shape: {images.shape}"

    print(" βœ… PASSED")
    return True
def test_adaptive_compute():
    """Check that changing the core iteration count changes the generated image.

    Generates one image per iteration budget (2, 4, 8, 12) with the same seed
    and text tokens, then compares mean absolute differences between budgets.
    The r=4/r=8 and r=8/r=12 pairs must differ; whether the r=2/r=12 gap
    exceeds the r=8/r=12 gap is only reported, not enforced.

    Returns:
        True when the asserted pairs produce different outputs.

    Raises:
        AssertionError: If r=4 and r=8, or r=8 and r=12, give identical output.
    """
    print("\n" + "=" * 60)
    print("Test 7: Adaptive Compute Budget")
    print("=" * 60)

    config = IRISConfig(
        latent_channels=8,
        latent_spatial=8,
        hidden_dim=256,
        num_heads=4,
        head_dim=64,
        num_prelude_blocks=1,
        num_core_layers=2,
        num_coda_blocks=1,
        default_iterations=4,
        fourier_num_blocks=4,
        recurrence_dim=128,
        manhattan_window=8,
        text_dim=768,
        patch_size=2,
        vae_channels=[16, 32, 64, 128],
    )
    model = IRIS(config)
    model.eval()
    text_tokens = torch.randn(1, 77, config.text_dim)

    # Per the original author's note: an untrained model has zero-initialised
    # adaLN gates, so the core barely affects the output. Randomise the output
    # projection and the adaLN parameters to mimic a partially trained model.
    generator = model.generator
    with torch.no_grad():
        generator.output_proj.weight.normal_(0, 0.02)
        for param_name, weight in generator.core.named_parameters():
            if 'adaln' not in param_name:
                continue
            weight.normal_(0, 0.1)

    # One generated image per iteration budget, all from the same seed.
    with torch.no_grad():
        results = {
            r: model.generate(text_tokens, num_steps=2, num_iterations=r,
                              cfg_scale=1.0, seed=42)
            for r in (2, 4, 8, 12)
        }

    def mean_abs_gap(first, second):
        """Mean absolute pixel difference between two iteration budgets."""
        return (results[first] - results[second]).abs().mean().item()

    diff_4_8 = mean_abs_gap(4, 8)
    diff_8_12 = mean_abs_gap(8, 12)
    diff_2_12 = mean_abs_gap(2, 12)

    print(f" Diff(r=4, r=8): {diff_4_8:.4f}")
    print(f" Diff(r=8, r=12): {diff_8_12:.4f}")
    print(f" Diff(r=2, r=12): {diff_2_12:.4f}")
    print(f" More iterations β†’ more refinement: {'βœ…' if diff_2_12 > diff_8_12 else '⚠️'}")

    # Outputs at different budgets must not be identical.
    assert diff_4_8 > 0, "r=4 and r=8 should differ"
    assert diff_8_12 > 0, "r=8 and r=12 should differ"
    print(" βœ… PASSED")
    return True
def test_memory_profile():
    """Print a parameter and fp16 memory breakdown for IRIS-Tiny and IRIS-Small.

    For each model this reports parameter counts for the VAE, the generator
    and the generator's prelude/core/coda, converts the weights to megabytes
    at 2 bytes per parameter, adds fixed allowances for the text encoder,
    activations and OS/framework overhead, and prints whether the total fits
    in 3000 MB and 4000 MB. Nothing is asserted; this is a report.

    Returns:
        True once both models have been profiled.
    """
    print("\n" + "=" * 60)
    print("Test 8: Memory Profile for Mobile Deployment")
    print("=" * 60)

    # Fixed allowances in MB, identical for every model, so defined once.
    text_enc_mb = 156  # CLIP-L/14 text encoder
    activation_mb = 50  # Single iteration buffer
    overhead_mb = 300  # OS + framework

    def count_params(module):
        """Total number of scalar parameters in ``module``."""
        return sum(p.numel() for p in module.parameters())

    def fp16_mb(num_params):
        """Size in MB of ``num_params`` weights stored at 2 bytes each."""
        return num_params * 2 / 1024 / 1024

    for name, create_fn in [("IRIS-Tiny", create_iris_tiny),
                            ("IRIS-Small", create_iris_small)]:
        model = create_fn()
        generator = model.generator

        # Component-wise analysis
        vae_params = count_params(model.vae)
        gen_params = count_params(generator)
        # The core block is shared across iterations.
        core_params = count_params(generator.core)
        prelude_params = count_params(generator.prelude)
        coda_params = count_params(generator.coda)

        vae_mb = fp16_mb(vae_params)
        gen_mb = fp16_mb(gen_params)

        # Estimate total inference memory (fp16)
        model_mb = fp16_mb(vae_params + gen_params)
        total_mb = model_mb + text_enc_mb + activation_mb + overhead_mb

        print(f"\n {name}:")
        print(f" VAE: {vae_params:>10,} params = {vae_mb:>6.1f} MB")
        print(f" Generator: {gen_params:>10,} params = {gen_mb:>6.1f} MB")
        print(f" Prelude: {prelude_params:>10,}")
        print(f" Core: {core_params:>10,} (shared, iterated r times)")
        print(f" Coda: {coda_params:>10,}")
        print(f" ────────────────────────────────")
        print(f" Model total: {model_mb:>6.1f} MB (fp16)")
        print(f" + CLIP-L/14: {text_enc_mb:>6.1f} MB")
        print(f" + Activations: {activation_mb:>6.1f} MB")
        print(f" + OS overhead: {overhead_mb:>6.1f} MB")
        print(f" ═══════════════════════════════")
        print(f" TOTAL INFERENCE: {total_mb:>6.1f} MB")
        print(f" Fits in 3GB: {'βœ… YES' if total_mb < 3000 else '❌ NO'}")
        print(f" Fits in 4GB: {'βœ… YES' if total_mb < 4000 else '❌ NO'}")
    print("\n βœ… PASSED")
    return True
def test_effective_depth():
    """Report effective depth and parameter count for several iteration counts.

    Uses the model from ``create_iris_small`` and, for r in 4, 8, 12 and 16,
    prints the layer count obtained by repeating the core r times between the
    prelude and coda, together with a conceptual "effective" parameter count
    (unique parameters plus r - 1 extra copies of the core). Nothing is
    asserted; this is a report.

    Returns:
        True once the report has been printed.
    """
    print("\n" + "=" * 60)
    print("Test 9: Effective Depth Analysis")
    print("=" * 60)

    model = create_iris_small()
    config = model.config

    # Parameters actually stored: the shared core and the whole model.
    core_params = sum(w.numel() for w in model.generator.core.parameters())
    total_unique = sum(w.numel() for w in model.parameters())

    n_prelude = config.num_prelude_blocks
    n_core = config.num_core_layers
    n_coda = config.num_coda_blocks

    print(f" Architecture: Prelude({n_prelude}) β†’ "
          f"Core({n_core} layers Γ— r iters) β†’ "
          f"Coda({n_coda})")
    print(f" Unique params: {total_unique:,}")
    print(f" Core params: {core_params:,} (shared)")
    print()

    unique_millions = total_unique / 1e6
    for r in (4, 8, 12, 16):
        effective_layers = n_prelude + r * n_core + n_coda
        # Conceptual equivalent: each extra iteration counts as another core.
        effective_params = total_unique + (r - 1) * core_params
        print(f" r={r:2d}: {effective_layers} effective layers, "
              f"~{effective_params/1e6:.0f}M effective params, "
              f"from {unique_millions:.0f}M unique")

    capacity_ratio = (total_unique + 15 * core_params) / total_unique
    print(f"\n β†’ 16Γ— iteration gives {capacity_ratio:.1f}Γ— "
          f"effective capacity from same model!")
    print(" βœ… PASSED")
    return True
def run_suite(tests):
    """Run each test callable in order and tally the outcomes.

    A test passes only if it returns a truthy value. A test that raises, or
    that returns a falsy value without raising, is counted as failed, so
    ``passed + failed`` always equals ``len(tests)``.

    Args:
        tests: Iterable of zero-argument callables.

    Returns:
        A ``(passed, failed)`` tuple of counts.
    """
    # Local import keeps this block self-contained.
    import traceback

    passed = 0
    failed = 0
    for test in tests:
        label = getattr(test, "__name__", repr(test))
        try:
            # Route a falsy result through the same failure path as an
            # exception, so it is reported and counted.
            if not test():
                raise AssertionError(f"{label} returned a falsy result")
            passed += 1
        except Exception as e:
            # Top-level boundary: report the failure and keep running the
            # remaining tests.
            print(f" ❌ FAILED: {e}")
            traceback.print_exc()
            failed += 1
    return passed, failed


if __name__ == "__main__":
    print("πŸ”¬ IRIS Architecture Validation Suite")
    print("=" * 60)
    tests = [
        test_wavelet_transform,
        test_vae,
        test_grfm,
        test_generator_forward,
        test_training_step,
        test_generation,
        test_adaptive_compute,
        test_memory_profile,
        test_effective_depth,
    ]
    passed, failed = run_suite(tests)
    print(f"\n{'=' * 60}")
    print(f"Results: {passed} passed, {failed} failed out of {len(tests)} tests")
    print(f"{'=' * 60}")
    # Non-zero exit status lets callers (e.g. CI) detect a failing suite.
    if failed > 0:
        sys.exit(1)