"""Comprehensive test suite for IRIS. 17 tests, all verified passing.

Run this file as a script to get a per-test report and a non-zero exit status
when anything fails. The zero-argument ``test_*`` functions can also be
collected directly by pytest; the ``test_module`` runner helper is excluded
from collection.
"""
import torch
import torch.nn.functional as F
import time
import sys
import traceback
import math


def test_module(name, fn):
    """Run one test callable, print a banner and verdict, and report success.

    Args:
        name: Human-readable test name shown in the banner.
        fn: Zero-argument callable that signals failure by raising.

    Returns:
        True if ``fn`` returned normally. False if it raised an ``Exception``;
        the traceback is printed and the caller can carry on with other tests.
    """
    print(f"\n{'='*60}\nTEST: {name}\n{'='*60}")
    # Pin the global RNG so every run of a test sees the same random tensors.
    torch.manual_seed(0)
    try:
        fn()
    except Exception as e:
        # A bare ``assert`` carries an empty message; include the exception
        # type so the verdict line is never just "FAILED: ".
        print(f" FAILED: {type(e).__name__}: {e}")
        traceback.print_exc()
        return False
    print(f" PASSED")
    return True


# ``test_module`` is a runner helper, not a test. Its name matches pytest's
# ``test_*`` discovery pattern, so without this pytest would collect it and
# fail trying to resolve ``name`` / ``fn`` as fixtures.
test_module.__test__ = False


def test_spectral_conv():
    """SpectralConv2d keeps the input shape, is NaN-free and trains ``weight_pos``."""
    from iris.pde_ssm import SpectralConv2d
    for H, W in [(4, 4), (8, 8), (16, 16)]:
        conv = SpectralConv2d(channels=32, modes_h=H // 2, modes_w=W // 2)
        x = torch.randn(2, 32, H, W)
        out = conv(x)
        assert out.shape == x.shape
        out.sum().backward()
        assert conv.weight_pos.grad is not None and conv.weight_pos.grad.norm() > 0
        assert not torch.isnan(out).any()
        print(f" SpectralConv2d({H}x{W}): OK")


def test_token_differential():
    """TokenDifferential is the identity at init and departs from it once alpha=1."""
    from iris.pde_ssm import TokenDifferential
    td = TokenDifferential(32)
    x = torch.randn(2, 32, 4, 4)
    assert torch.allclose(td(x), x, atol=1e-6)
    td.alpha.data.fill_(1.0)
    assert not torch.allclose(td(x), x)
    print(" TokenDiff: identity at init, non-identity with alpha=1")


def test_pde_ssm_block():
    """PDESSMBlock keeps token shape and gives every trainable param a finite grad."""
    from iris.pde_ssm import PDESSMBlock
    for s in [4, 8]:
        b = PDESSMBlock(dim=64, spatial_size=s)
        x = torch.randn(2, s * s, 64)  # (batch, s*s tokens, dim)
        out = b(x, s, s)
        assert out.shape == x.shape
        out.sum().backward()
        for n, p in b.named_parameters():
            if p.requires_grad:
                assert p.grad is not None and not torch.isnan(p.grad).any(), (
                    f"missing or NaN gradient for {n}"
                )
        print(f" PDESSMBlock(s={s}): OK")


def test_cross_attention():
    """Multi-query cross-attention: query-shaped output, K projection smaller than Q."""
    from iris.blocks import MultiQueryCrossAttention
    a = MultiQueryCrossAttention(dim=64, num_heads=4)
    out = a(torch.randn(2, 16, 64), torch.randn(2, 32, 64))
    assert out.shape == (2, 16, 64)
    out.sum().backward()
    assert a.k_proj.weight.numel() < a.q_proj.weight.numel()
    print(f" CrossAttn MQA: OK, K/Q ratio = {a.q_proj.weight.numel()//a.k_proj.weight.numel()}x")


def test_self_attention():
    """Multi-query self-attention over a 4x4 token grid keeps shape and backprops."""
    from iris.blocks import MultiQuerySelfAttention
    a = MultiQuerySelfAttention(dim=64, num_heads=4)
    out = a(torch.randn(2, 16, 64), 4, 4)
    assert out.shape == (2, 16, 64)
    out.sum().backward()
    print(" SelfAttn+2D RoPE: OK")


def test_rope_2d():
    """2D rotary embedding keeps shape and the mean last-dim norm (within 10%)."""
    from iris.blocks import RotaryEmbedding2D
    rope = RotaryEmbedding2D(dim=16)
    x = torch.randn(2, 4, 16, 16)
    out = rope(x, 4, 4)
    assert out.shape == x.shape
    in_norm = x.norm(dim=-1).mean()
    out_norm = out.norm(dim=-1).mean()
    assert abs(in_norm - out_norm) / in_norm < 0.1
    print(" 2D RoPE: norm preserved")


def test_uib_ffn():
    """UIB-FFN keeps token shape; its dw_conv uses 128 groups (dim 64 * expansion 2)."""
    from iris.blocks import UIBFFN
    f = UIBFFN(dim=64, expansion=2)
    out = f(torch.randn(2, 16, 64), 4, 4)
    assert out.shape == (2, 16, 64)
    out.sum().backward()
    assert f.dw_conv.groups == 128
    print(" UIB-FFN: OK")


def test_timestep_embedding():
    """TimestepEmbedding maps t in [0, 1] to NaN-free (N, dim) vectors."""
    from iris.blocks import TimestepEmbedding
    te = TimestepEmbedding(dim=64)
    out = te(torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0]))
    assert out.shape == (5, 64) and not torch.isnan(out).any()
    print(" TimestepEmbed: OK")


def test_patchify_unpatchify():
    """Patchify yields (16/ps)^2 tokens and Unpatchify restores the latent shape."""
    from iris.model import Patchify, Unpatchify
    for ps in [2, 4]:
        dim = 128 if ps == 2 else 512
        p, u = Patchify(32, dim, ps), Unpatchify(32, dim, ps)
        z = torch.randn(2, 32, 16, 16)
        tok, H, W = p(z)
        assert tok.shape == (2, (16 // ps) ** 2, dim)
        assert u(tok, H, W).shape == z.shape
        print(f" Patchify(ps={ps}): OK")


def test_tiny_decoder():
    """TinyDecoder turns a 32x16x16 latent into a 3x512x512 image with < 2M params."""
    from iris.model import TinyDecoder
    d = TinyDecoder(32, 3)
    img = d(torch.randn(2, 32, 16, 16))
    assert img.shape == (2, 3, 512, 512)
    n = sum(p.numel() for p in d.parameters())
    assert n < 2_000_000
    print(f" TinyDecoder: {n:,} params, output {img.shape}")


def test_iris_forward():
    """Full IRIS forward/backward; after one SGD step every core param gets a gradient."""
    from iris.model import IRIS
    m = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4,
             max_iterations=4, gradient_checkpointing=False)
    z = torch.randn(2, 32, 16, 16)
    t = torch.tensor([0.3, 0.7])
    ctx = torch.randn(2, 8, 128)
    v = m(z, t, ctx, num_iterations=2)
    assert v.shape == z.shape and not torch.isnan(v).any()
    v.sum().backward()
    # Some output layers start at zero (see test_zero_init), which blocks
    # gradient to everything upstream. One SGD step moves them off zero before
    # the gradient-coverage check on the second pass.
    opt = torch.optim.SGD(m.parameters(), lr=0.01)
    opt.step()
    opt.zero_grad(set_to_none=True)
    m(z, t, ctx, 2).sum().backward()
    core_params = [
        (n, p) for n, p in m.named_parameters()
        if p.requires_grad and "tiny_decoder" not in n
    ]
    # ``not (norm > eps)`` rather than ``norm <= eps`` so NaN norms are flagged too.
    starved = [
        n for n, p in core_params
        if p.grad is None or not (p.grad.norm() > 1e-10)
    ]
    assert not starved, f"core params without gradient: {starved}"
    print(f" Forward OK, all {len(core_params)} core params have grad, total={m.count_params()['total']:,}")


def test_flow_matching_loss():
    """flow_matching_loss returns a positive, finite, differentiable ``loss``."""
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss
    m = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4,
             max_iterations=4, gradient_checkpointing=False)
    result = flow_matching_loss(m, torch.randn(4, 32, 16, 16) * 2.5,
                                torch.randn(4, 8, 128), num_iterations=2)
    loss = result["loss"]
    assert loss.requires_grad and not torch.isnan(loss) and loss.item() > 0
    loss.backward()
    print(f" flow_loss={result['flow_loss'].item():.4f}")


def test_euler_sampling():
    """Euler sampling yields a NaN-free latent that decodes to a 512x512 RGB image."""
    from iris.model import IRIS
    from iris.flow_matching import euler_sample
    m = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4,
             max_iterations=4, gradient_checkpointing=False)
    m.eval()
    with torch.no_grad():
        z = euler_sample(m, torch.randn(1, 32, 16, 16), torch.randn(1, 8, 128),
                         num_steps=5, num_iterations=2)
        assert z.shape == (1, 32, 16, 16) and not torch.isnan(z).any()
        img = m.decode_latent(z)
    assert img.shape == (1, 3, 512, 512)
    print(f" Euler sampling OK, decoded {img.shape}")


def test_gradient_checkpointing():
    """Gradient checkpointing must not change the loss or the gradients."""
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss
    m1 = IRIS(latent_channels=32, dim=64, patch_size=4, num_blocks=3, num_heads=4,
              max_iterations=4, gradient_checkpointing=False)
    m2 = IRIS(latent_channels=32, dim=64, patch_size=4, num_blocks=3, num_heads=4,
              max_iterations=4, gradient_checkpointing=True)
    m2.load_state_dict(m1.state_dict())
    torch.manual_seed(42)
    z, ctx = torch.randn(2, 32, 16, 16) * 2.5, torch.randn(2, 4, 64)
    # Re-seed before each loss call so both models see identical internal randomness.
    torch.manual_seed(123)
    l1 = flow_matching_loss(m1, z, ctx, num_iterations=3)
    l1["loss"].backward()
    torch.manual_seed(123)
    l2 = flow_matching_loss(m2, z, ctx, num_iterations=3)
    l2["loss"].backward()
    diff = abs(l1["loss"].item() - l2["loss"].item())

    params1 = list(m1.named_parameters())
    params2 = list(m2.named_parameters())
    # Gradients are compared pairwise, so the two parameter lists must line up.
    assert [n for n, _ in params1] == [n for n, _ in params2], "parameter layouts differ"
    grad_diffs = [
        (p1.grad - p2.grad).abs().max().item()
        for (_, p1), (_, p2) in zip(params1, params2)
        if p1.grad is not None and p2.grad is not None
    ]
    assert grad_diffs, "no parameter received a gradient in both models"
    maxg = max(grad_diffs)
    assert diff < 1e-6, f"loss differs with checkpointing: {diff}"
    assert maxg < 1e-4, f"gradients differ with checkpointing: {maxg}"
    print(f" Checkpointing: loss diff={diff:.8f}, max grad diff={maxg:.8f}")


def test_weight_sharing():
    """RefinementCore output changes with the number of iterations."""
    from iris.core import RefinementCore
    c = RefinementCore(dim=64, num_blocks=3, num_heads=4, spatial_size=4,
                       max_iterations=4, gradient_checkpointing=False)
    x, ctx, t = torch.randn(1, 16, 64), torch.randn(1, 4, 64), torch.tensor([0.5])
    o2 = c(x, ctx, t, 4, 4, num_iterations=2)
    o4 = c(x, ctx, t, 4, 4, num_iterations=4)
    assert not torch.allclose(o2, o4, atol=1e-3)
    print(f" Weight sharing: {sum(p.numel() for p in c.parameters()):,} params (constant)")


def test_zero_init():
    """The final unpatchify projection starts at zero, so the initial output is tiny."""
    from iris.model import IRIS
    m = IRIS(latent_channels=32, dim=64, patch_size=4, num_blocks=3, num_heads=4,
             gradient_checkpointing=False)
    assert (m.unpatchify.proj.weight == 0).all() and (m.unpatchify.proj.bias == 0).all()
    with torch.no_grad():
        out = m(torch.randn(1, 32, 16, 16), torch.tensor([0.5]), torch.randn(1, 4, 64), 1)
    assert out.norm().item() < 1.0
    print(f" Zero-init: output norm={out.norm().item():.6f}")


def test_training_stability():
    """100 AdamW steps on a fixed batch: losses stay finite and the mean loss drops."""
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss
    # Seed before construction so the initial weights are reproducible too.
    torch.manual_seed(0)
    m = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4,
             max_iterations=4, gradient_checkpointing=False)
    opt = torch.optim.AdamW(m.parameters(), lr=1e-3, weight_decay=0.01)
    torch.manual_seed(0)
    z, ctx = torch.randn(8, 32, 16, 16) * 2.5, torch.randn(8, 8, 128)
    losses = []
    m.train()
    for step in range(100):
        result = flow_matching_loss(m, z, ctx, num_iterations=2)
        opt.zero_grad(set_to_none=True)
        result["loss"].backward()
        torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0)
        opt.step()
        losses.append(result["loss"].item())
        if (step + 1) % 25 == 0:
            print(f" Step {step+1}: loss={losses[-1]:.4f}")
    f10, l10 = sum(losses[:10]) / 10, sum(losses[-10:]) / 10
    print(f" Loss: {f10:.4f} -> {l10:.4f} ({(1-l10/f10)*100:.1f}%)")
    assert all(math.isfinite(v) for v in losses), "non-finite loss encountered"
    assert l10 < f10, f"mean loss did not decrease: {f10:.4f} -> {l10:.4f}"


if __name__ == "__main__":
    tests = [
        ("SpectralConv2d", test_spectral_conv),
        ("TokenDifferential", test_token_differential),
        ("PDESSMBlock", test_pde_ssm_block),
        ("CrossAttention (MQA)", test_cross_attention),
        ("SelfAttention (MQA+2D RoPE)", test_self_attention),
        ("2D RoPE", test_rope_2d),
        ("UIB-FFN", test_uib_ffn),
        ("TimestepEmbedding", test_timestep_embedding),
        ("Patchify/Unpatchify", test_patchify_unpatchify),
        ("TinyDecoder", test_tiny_decoder),
        ("IRIS Forward", test_iris_forward),
        ("Flow Matching Loss", test_flow_matching_loss),
        ("Euler Sampling", test_euler_sampling),
        ("Gradient Checkpointing", test_gradient_checkpointing),
        ("Weight Sharing", test_weight_sharing),
        ("Zero-Init Output", test_zero_init),
        ("Training Stability (100 steps)", test_training_stability),
    ]
    failed = [name for name, fn in tests if not test_module(name, fn)]
    passed = len(tests) - len(failed)
    print(f"\n{'='*60}\nRESULTS: {passed}/{len(tests)} passed\n{'='*60}")
    if failed:
        print("Failed: " + ", ".join(failed))
        sys.exit(1)