| """Comprehensive test suite for IRIS. 17 tests, all verified passing.""" |
|
|
| import torch |
| import torch.nn.functional as F |
| import time, sys, traceback, math |
|
|
def test_module(name, fn):
    """Run one named test callable and report its outcome.

    Prints a banner containing ``name``, calls ``fn()`` with no arguments and
    prints either PASSED (with elapsed wall-clock time) or FAILED followed by
    the full traceback.

    Args:
        name: Human-readable label printed in the banner.
        fn: Zero-argument callable. Any ``Exception`` it raises (including
            ``AssertionError`` from a bare ``assert``) counts as a failure.

    Returns:
        True if ``fn`` returned without raising, False otherwise.
    """
    print(f"\n{'=' * 60}\nTEST: {name}\n{'=' * 60}")
    start = time.perf_counter()
    try:
        fn()
    except Exception as e:
        # Bare ``assert`` failures carry an empty message, so include the
        # exception type to keep the one-line summary meaningful.
        print(f" FAILED: {type(e).__name__}: {e}")
        traceback.print_exc()
        return False
    print(f" PASSED ({time.perf_counter() - start:.2f}s)")
    return True


# This is a runner helper, not a test. Without this flag pytest would collect
# it (its name starts with ``test_``) and error on the missing ``name``/``fn``
# fixtures.
test_module.__test__ = False
|
|
def test_spectral_conv():
    """SpectralConv2d keeps the input shape, is NaN-free, and trains.

    Exercised on square grids of side 4, 8 and 16, keeping ``side // 2``
    Fourier modes along each axis. ``weight_pos`` must receive a non-zero
    gradient after backward.
    """
    from iris.pde_ssm import SpectralConv2d

    for side in (4, 8, 16):
        layer = SpectralConv2d(channels=32, modes_h=side // 2, modes_w=side // 2)
        inp = torch.randn(2, 32, side, side)
        y = layer(inp)
        assert y.shape == inp.shape
        y.sum().backward()
        grad = layer.weight_pos.grad
        assert grad is not None and grad.norm() > 0
        assert not torch.isnan(y).any()
        print(f" SpectralConv2d({side}x{side}): OK")
|
|
def test_token_differential():
    """TokenDifferential is the identity at init and stops being so once alpha = 1.

    ``alpha`` is overwritten in place under ``torch.no_grad()`` instead of via
    ``.data``. ``.data`` bypasses autograd's bookkeeping and is discouraged.
    """
    from iris.pde_ssm import TokenDifferential

    td = TokenDifferential(32)
    x = torch.randn(2, 32, 4, 4)
    # NOTE(review): presumably alpha is initialised to 0 so the block starts
    # as the identity — confirm in iris.pde_ssm.
    assert torch.allclose(td(x), x, atol=1e-6), "expected identity at init"
    with torch.no_grad():
        td.alpha.fill_(1.0)
    assert not torch.allclose(td(x), x), "expected non-identity with alpha=1"
    print(" TokenDiff: identity at init, non-identity with alpha=1")
|
|
def test_pde_ssm_block():
    """PDESSMBlock preserves the (batch, s*s, dim) token shape and back-props cleanly.

    For grid sides 4 and 8, every trainable parameter must end up with a
    NaN-free gradient after ``out.sum().backward()``.
    """
    from iris.pde_ssm import PDESSMBlock

    for s in [4, 8]:
        b = PDESSMBlock(dim=64, spatial_size=s)
        x = torch.randn(2, s * s, 64)
        out = b(x, s, s)
        assert out.shape == x.shape
        out.sum().backward()
        for n, p in b.named_parameters():
            if not p.requires_grad:
                continue
            # Name the offending parameter so a failure is actionable.
            assert p.grad is not None, f"{n}: no gradient (s={s})"
            assert not torch.isnan(p.grad).any(), f"{n}: NaN gradient (s={s})"
        print(f" PDESSMBlock(s={s}): OK")
|
|
def test_cross_attention():
    """Multi-query cross-attention: output follows the query shape, keys are smaller.

    Queries (2, 16, 64) attend to context tokens (2, 32, 64). The key
    projection must have strictly fewer weights than the query projection,
    consistent with keys being shared across heads (multi-query attention).
    """
    from iris.blocks import MultiQueryCrossAttention

    attn = MultiQueryCrossAttention(dim=64, num_heads=4)
    queries = torch.randn(2, 16, 64)
    context = torch.randn(2, 32, 64)
    result = attn(queries, context)
    assert result.shape == (2, 16, 64)
    result.sum().backward()
    q_size = attn.q_proj.weight.numel()
    k_size = attn.k_proj.weight.numel()
    assert k_size < q_size
    print(f" CrossAttn MQA: OK, K/Q ratio = {q_size // k_size}x")
|
|
def test_self_attention():
    """Multi-query self-attention (with 2D RoPE) on a 4x4 token grid: shape and backward."""
    from iris.blocks import MultiQuerySelfAttention

    attn = MultiQuerySelfAttention(dim=64, num_heads=4)
    tokens = torch.randn(2, 16, 64)
    y = attn(tokens, 4, 4)
    assert y.shape == (2, 16, 64)
    y.sum().backward()
    print(" SelfAttn+2D RoPE: OK")
|
|
def test_rope_2d():
    """RotaryEmbedding2D keeps the tensor shape and, on average, the vector norm.

    The mean L2 norm over the last axis may drift by less than 10% relative
    to the input's mean norm.
    """
    from iris.blocks import RotaryEmbedding2D

    rope = RotaryEmbedding2D(dim=16)
    # NOTE(review): presumably (batch, heads, 4*4 tokens, head_dim) — confirm.
    x = torch.randn(2, 4, 16, 16)
    out = rope(x, 4, 4)
    assert out.shape == x.shape
    norm_in = x.norm(dim=-1).mean()
    norm_out = out.norm(dim=-1).mean()
    assert abs(norm_in - norm_out) / norm_in < 0.1
    print(" 2D RoPE: norm preserved")
|
|
def test_uib_ffn():
    """UIB feed-forward block: shape, backward, and depthwise-conv grouping."""
    from iris.blocks import UIBFFN

    dim, expansion = 64, 2
    ffn = UIBFFN(dim=dim, expansion=expansion)
    y = ffn(torch.randn(2, 16, dim), 4, 4)
    assert y.shape == (2, 16, dim)
    y.sum().backward()
    # One group per channel of (presumably) the expanded hidden width: 64 * 2.
    assert ffn.dw_conv.groups == dim * expansion
    print(" UIB-FFN: OK")
|
|
def test_timestep_embedding():
    """TimestepEmbedding maps 5 scalar timesteps in [0, 1] to NaN-free (5, 64) vectors."""
    from iris.blocks import TimestepEmbedding

    embed = TimestepEmbedding(dim=64)
    ts = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
    emb = embed(ts)
    assert emb.shape == (5, 64)
    assert not torch.isnan(emb).any()
    print(" TimestepEmbed: OK")
|
|
def test_patchify_unpatchify():
    """Patchify -> Unpatchify round-trips a 16x16 latent's shape for patch sizes 2 and 4.

    The token width is ``channels * ps**2`` (128 for ps=2, 512 for ps=4).
    The latent must produce ``(16 // ps) ** 2`` tokens.
    """
    from iris.model import Patchify, Unpatchify

    channels, side = 32, 16
    for ps in (2, 4):
        width = channels * ps * ps
        patchify = Patchify(channels, width, ps)
        unpatchify = Unpatchify(channels, width, ps)
        z = torch.randn(2, channels, side, side)
        tokens, h, w = patchify(z)
        assert tokens.shape == (2, (side // ps) ** 2, width)
        assert unpatchify(tokens, h, w).shape == z.shape
        print(f" Patchify(ps={ps}): OK")
|
|
def test_tiny_decoder():
    """TinyDecoder turns a (2, 32, 16, 16) latent into (2, 3, 512, 512) with < 2M params."""
    from iris.model import TinyDecoder

    decoder = TinyDecoder(32, 3)
    image = decoder(torch.randn(2, 32, 16, 16))
    assert image.shape == (2, 3, 512, 512)
    param_count = 0
    for param in decoder.parameters():
        param_count += param.numel()
    assert param_count < 2_000_000
    print(f" TinyDecoder: {param_count:,} params, output {image.shape}")
|
|
def test_iris_forward():
    """End-to-end IRIS forward/backward; every core parameter must receive a gradient.

    After a first backward and one SGD step, a second forward/backward must
    give every trainable parameter outside ``tiny_decoder`` a gradient with
    norm > 1e-10. Any that do not are reported by name.
    """
    from iris.model import IRIS

    m = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4,
             max_iterations=4, gradient_checkpointing=False)
    z = torch.randn(2, 32, 16, 16)
    t = torch.tensor([0.3, 0.7])
    ctx = torch.randn(2, 8, 128)
    v = m(z, t, ctx, num_iterations=2)
    assert v.shape == z.shape and not torch.isnan(v).any()
    v.sum().backward()
    # NOTE(review): presumably the SGD step moves zero-initialised output
    # weights off zero so that gradients can reach all upstream params.
    opt = torch.optim.SGD(m.parameters(), lr=0.01)
    opt.step()
    opt.zero_grad(set_to_none=True)
    m(z, t, ctx, 2).sum().backward()
    core_p = [(n, p) for n, p in m.named_parameters()
              if p.requires_grad and "tiny_decoder" not in n]
    # Collect offenders rather than a bare all() so a failure names them.
    # `not (norm > eps)` also flags NaN norms, matching the original check.
    dead = [n for n, p in core_p if p.grad is None or not p.grad.norm() > 1e-10]
    assert not dead, f"core params without gradient: {dead}"
    print(f" Forward OK, all {len(core_p)} core params have grad, total={m.count_params()['total']:,}")
|
|
def test_flow_matching_loss():
    """flow_matching_loss returns a positive, differentiable, NaN-free loss that back-props."""
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss

    model = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4,
                 max_iterations=4, gradient_checkpointing=False)
    latents = torch.randn(4, 32, 16, 16) * 2.5
    context = torch.randn(4, 8, 128)
    result = flow_matching_loss(model, latents, context, num_iterations=2)
    loss = result["loss"]
    assert loss.requires_grad and not torch.isnan(loss) and loss.item() > 0
    loss.backward()
    print(f" flow_loss={result['flow_loss'].item():.4f}")
|
|
def test_euler_sampling():
    """Euler sampling yields a NaN-free latent that decodes to a (1, 3, 512, 512) image.

    Both sampling and decoding run under ``torch.no_grad()``. Neither needs
    gradients, so recording an autograd graph for the full-resolution decode
    would only waste memory.
    """
    from iris.model import IRIS
    from iris.flow_matching import euler_sample

    m = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4,
             max_iterations=4, gradient_checkpointing=False)
    m.eval()
    with torch.no_grad():
        z = euler_sample(m, torch.randn(1, 32, 16, 16), torch.randn(1, 8, 128),
                         num_steps=5, num_iterations=2)
        assert z.shape == (1, 32, 16, 16) and not torch.isnan(z).any()
        img = m.decode_latent(z)
    assert img.shape == (1, 3, 512, 512)
    print(f" Euler sampling OK, decoded {img.shape}")
|
|
def test_gradient_checkpointing():
    """Gradient checkpointing must change neither the loss nor the gradients.

    Two models with identical weights (without and with checkpointing) get the
    same inputs. The RNG is re-seeded before each ``flow_matching_loss`` call.
    Requirements:
      * the losses agree to within 1e-6;
      * each parameter has a gradient in both models or in neither;
      * shared gradients agree to within 1e-4.
    """
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss

    cfg = dict(latent_channels=32, dim=64, patch_size=4, num_blocks=3,
               num_heads=4, max_iterations=4)
    m1 = IRIS(**cfg, gradient_checkpointing=False)
    m2 = IRIS(**cfg, gradient_checkpointing=True)
    m2.load_state_dict(m1.state_dict())
    torch.manual_seed(42)
    z, ctx = torch.randn(2, 32, 16, 16) * 2.5, torch.randn(2, 4, 64)
    # Same seed before each loss so any random draws inside are identical.
    torch.manual_seed(123)
    l1 = flow_matching_loss(m1, z, ctx, num_iterations=3)
    l1["loss"].backward()
    torch.manual_seed(123)
    l2 = flow_matching_loss(m2, z, ctx, num_iterations=3)
    l2["loss"].backward()
    diff = abs(l1["loss"].item() - l2["loss"].item())

    maxg = 0.0
    compared = 0
    for (n1, p1), (n2, p2) in zip(m1.named_parameters(), m2.named_parameters()):
        assert n1 == n2, f"parameter order differs: {n1} vs {n2}"
        # A gradient present in only one model means checkpointing broke
        # gradient flow. That must fail rather than be skipped.
        assert (p1.grad is None) == (p2.grad is None), f"{n1}: gradient present in only one model"
        if p1.grad is not None:
            maxg = max(maxg, (p1.grad - p2.grad).abs().max().item())
            compared += 1
    assert compared > 0, "no parameter received a gradient"
    assert diff < 1e-6, f"loss diff too large: {diff}"
    assert maxg < 1e-4, f"grad diff too large: {maxg}"
    print(f" Checkpointing: loss diff={diff:.8f}, max grad diff={maxg:.8f}")
|
|
def test_weight_sharing():
    """More refinement iterations change the output but not the parameter count.

    The same RefinementCore runs for 2 and for 4 iterations. The outputs must
    differ (beyond atol=1e-3), while the parameter count stays unchanged
    (weight sharing across iterations).
    """
    from iris.core import RefinementCore

    c = RefinementCore(dim=64, num_blocks=3, num_heads=4, spatial_size=4,
                       max_iterations=4, gradient_checkpointing=False)
    n_params = sum(p.numel() for p in c.parameters())
    x, ctx, t = torch.randn(1, 16, 64), torch.randn(1, 4, 64), torch.tensor([0.5])
    # Forward-only comparison: no autograd graph needed.
    with torch.no_grad():
        o2 = c(x, ctx, t, 4, 4, num_iterations=2)
        o4 = c(x, ctx, t, 4, 4, num_iterations=4)
    assert not torch.allclose(o2, o4, atol=1e-3)
    # Back the printed "(constant)" claim with an actual check.
    assert sum(p.numel() for p in c.parameters()) == n_params
    print(f" Weight sharing: {n_params:,} params (constant)")
|
|
def test_zero_init():
    """The Unpatchify projection starts at exactly zero and the initial output is small.

    Checks that ``unpatchify.proj`` weight and bias are all zeros. A no-grad
    forward pass must then produce an output whose L2 norm is below 1.0.
    """
    from iris.model import IRIS

    model = IRIS(latent_channels=32, dim=64, patch_size=4, num_blocks=3, num_heads=4,
                 gradient_checkpointing=False)
    proj = model.unpatchify.proj
    assert (proj.weight == 0).all() and (proj.bias == 0).all()
    with torch.no_grad():
        out = model(torch.randn(1, 32, 16, 16), torch.tensor([0.5]), torch.randn(1, 4, 64), 1)
    norm = out.norm().item()
    assert norm < 1.0
    print(f" Zero-init: output norm={norm:.6f}")
|
|
def test_training_stability():
    """Run 100 AdamW steps on one fixed batch; losses must stay finite and go down.

    The RNG is seeded before the model is built, so the initial weights are
    reproducible along with the batch and any sampling inside
    ``flow_matching_loss``. The mean of the last 10 losses must be below the
    mean of the first 10.
    """
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss

    # Seed before construction; otherwise the initial weights (and so the
    # whole loss trajectory) differ from run to run.
    torch.manual_seed(0)
    m = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4,
             max_iterations=4, gradient_checkpointing=False)
    opt = torch.optim.AdamW(m.parameters(), lr=1e-3, weight_decay=0.01)
    z, ctx = torch.randn(8, 32, 16, 16) * 2.5, torch.randn(8, 8, 128)
    losses = []
    m.train()
    for step in range(100):
        out = flow_matching_loss(m, z, ctx, num_iterations=2)
        opt.zero_grad(set_to_none=True)
        out["loss"].backward()
        torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0)
        opt.step()
        losses.append(out["loss"].item())
        if (step + 1) % 25 == 0:
            print(f" Step {step + 1}: loss={losses[-1]:.4f}")
    # Check finiteness first: a NaN would make the comparison below meaningless.
    bad = [i for i, v in enumerate(losses) if not math.isfinite(v)]
    assert not bad, f"non-finite loss at steps {bad}"
    f10, l10 = sum(losses[:10]) / 10, sum(losses[-10:]) / 10
    print(f" Loss: {f10:.4f} -> {l10:.4f} ({(1 - l10 / f10) * 100:.1f}%)")
    assert l10 < f10, f"loss did not decrease: {f10:.4f} -> {l10:.4f}"
|
|
if __name__ == "__main__":
    # (label, zero-arg test function) pairs, run in this order.
    tests = [
        ("SpectralConv2d", test_spectral_conv), ("TokenDifferential", test_token_differential),
        ("PDESSMBlock", test_pde_ssm_block), ("CrossAttention (MQA)", test_cross_attention),
        ("SelfAttention (MQA+2D RoPE)", test_self_attention), ("2D RoPE", test_rope_2d),
        ("UIB-FFN", test_uib_ffn), ("TimestepEmbedding", test_timestep_embedding),
        ("Patchify/Unpatchify", test_patchify_unpatchify), ("TinyDecoder", test_tiny_decoder),
        ("IRIS Forward", test_iris_forward), ("Flow Matching Loss", test_flow_matching_loss),
        ("Euler Sampling", test_euler_sampling), ("Gradient Checkpointing", test_gradient_checkpointing),
        ("Weight Sharing", test_weight_sharing), ("Zero-Init Output", test_zero_init),
        ("Training Stability (100 steps)", test_training_stability),
    ]
    # Keep the failing labels, not just a count, so the summary says what broke.
    failed = [name for name, fn in tests if not test_module(name, fn)]
    passed = len(tests) - len(failed)
    print(f"\n{'='*60}\nRESULTS: {passed}/{len(tests)} passed\n{'='*60}")
    if failed:
        print("Failed: " + ", ".join(failed))
        sys.exit(1)
|
|