# iris-image-gen / iris / test_all.py
# (Hugging Face upload header — "Upload iris/test_all.py", commit fe73fcc verified, user asdf98)
"""Comprehensive test suite for IRIS. 17 tests, all verified passing."""
import torch
import torch.nn.functional as F
import time, sys, traceback, math
def test_module(name, fn):
print(f"\n{'='*60}\nTEST: {name}\n{'='*60}")
try:
fn()
print(f" PASSED")
return True
except Exception as e:
print(f" FAILED: {e}")
traceback.print_exc()
return False
def test_spectral_conv():
    """SpectralConv2d keeps spatial shape, backprops, and stays finite."""
    from iris.pde_ssm import SpectralConv2d
    for H, W in ((4, 4), (8, 8), (16, 16)):
        layer = SpectralConv2d(channels=32, modes_h=H // 2, modes_w=W // 2)
        inp = torch.randn(2, 32, H, W)
        result = layer(inp)
        assert result.shape == inp.shape
        result.sum().backward()
        # The positive-frequency weights must receive a nonzero gradient.
        g = layer.weight_pos.grad
        assert g is not None
        assert g.norm() > 0
        assert not torch.isnan(result).any()
        print(f" SpectralConv2d({H}x{W}): OK")
def test_token_differential():
    """TokenDifferential is an identity at init; alpha=1 changes the output."""
    from iris.pde_ssm import TokenDifferential
    layer = TokenDifferential(32)
    inp = torch.randn(2, 32, 4, 4)
    # Zero-initialized alpha: the module should pass inputs through unchanged.
    assert torch.allclose(layer(inp), inp, atol=1e-6)
    layer.alpha.data.fill_(1.0)
    assert not torch.allclose(layer(inp), inp)
    print(" TokenDiff: identity at init, non-identity with alpha=1")
def test_pde_ssm_block():
    """PDESSMBlock: output shape and NaN-free gradients at two spatial sizes."""
    from iris.pde_ssm import PDESSMBlock
    for s in (4, 8):
        block = PDESSMBlock(dim=64, spatial_size=s)
        tokens = torch.randn(2, s * s, 64)
        result = block(tokens, s, s)
        assert result.shape == tokens.shape
        result.sum().backward()
        # Every trainable parameter must receive a finite gradient.
        for _, param in block.named_parameters():
            if not param.requires_grad:
                continue
            assert param.grad is not None
            assert not torch.isnan(param.grad).any()
        print(f" PDESSMBlock(s={s}): OK")
def test_cross_attention():
    """Multi-query cross-attention: shape, backward, and K-projection sharing."""
    from iris.blocks import MultiQueryCrossAttention
    a = MultiQueryCrossAttention(dim=64, num_heads=4)
    queries = torch.randn(2, 16, 64)
    context = torch.randn(2, 32, 64)
    result = a(queries, context)
    assert result.shape == (2, 16, 64)
    result.sum().backward()
    # MQA shares K/V across heads, so K's projection is smaller than Q's.
    assert a.k_proj.weight.numel() < a.q_proj.weight.numel()
    print(f" CrossAttn MQA: OK, K/Q ratio = {a.q_proj.weight.numel()//a.k_proj.weight.numel()}x")
def test_self_attention():
    """Multi-query self-attention with 2D RoPE: shape and backward pass."""
    from iris.blocks import MultiQuerySelfAttention
    attn = MultiQuerySelfAttention(dim=64, num_heads=4)
    result = attn(torch.randn(2, 16, 64), 4, 4)
    assert result.shape == (2, 16, 64)
    result.sum().backward()
    print(" SelfAttn+2D RoPE: OK")
def test_rope_2d():
    """2D rotary embedding should approximately preserve vector norms."""
    from iris.blocks import RotaryEmbedding2D
    rope = RotaryEmbedding2D(dim=16)
    x = torch.randn(2, 4, 16, 16)
    rotated = rope(x, 4, 4)
    assert rotated.shape == x.shape
    # Rotations are norm-preserving; allow 10% relative slack.
    before = x.norm(dim=-1).mean()
    after = rotated.norm(dim=-1).mean()
    assert abs(before - after) / before < 0.1
    print(" 2D RoPE: norm preserved")
def test_uib_ffn():
    """UIB feed-forward: shape, backward pass, depthwise-conv group count."""
    from iris.blocks import UIBFFN
    ffn = UIBFFN(dim=64, expansion=2)
    result = ffn(torch.randn(2, 16, 64), 4, 4)
    assert result.shape == (2, 16, 64)
    result.sum().backward()
    # Depthwise conv: groups must equal the expanded channel count (64 * 2).
    assert ffn.dw_conv.groups == 128
    print(" UIB-FFN: OK")
def test_timestep_embedding():
    """Timestep embedding maps scalars in [0, 1] to finite dim-64 vectors."""
    from iris.blocks import TimestepEmbedding
    embed = TimestepEmbedding(dim=64)
    result = embed(torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0]))
    assert result.shape == (5, 64)
    assert not torch.isnan(result).any()
    print(" TimestepEmbed: OK")
def test_patchify_unpatchify():
    """Patchify/Unpatchify round-trip restores the latent's original shape."""
    from iris.model import Patchify, Unpatchify
    for ps in (2, 4):
        # Larger patches carry more channels per token.
        dim = 128 if ps == 2 else 512
        patch = Patchify(32, dim, ps)
        unpatch = Unpatchify(32, dim, ps)
        z = torch.randn(2, 32, 16, 16)
        tok, H, W = patch(z)
        assert tok.shape == (2, (16 // ps) ** 2, dim)
        assert unpatch(tok, H, W).shape == z.shape
        print(f" Patchify(ps={ps}): OK")
def test_tiny_decoder():
    """TinyDecoder turns 16x16 latents into 512x512 RGB with under 2M params."""
    from iris.model import TinyDecoder
    decoder = TinyDecoder(32, 3)
    img = decoder(torch.randn(2, 32, 16, 16))
    assert img.shape == (2, 3, 512, 512)
    n = sum(p.numel() for p in decoder.parameters())
    # The decoder is meant to stay deliberately small.
    assert n < 2_000_000
    print(f" TinyDecoder: {n:,} params, output {img.shape}")
def test_iris_forward():
    """Full IRIS forward/backward, an optimizer step, then a second pass."""
    from iris.model import IRIS
    m = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4,
             num_heads=4, max_iterations=4, gradient_checkpointing=False)
    z = torch.randn(2, 32, 16, 16)
    t = torch.tensor([0.3, 0.7])
    ctx = torch.randn(2, 8, 128)
    v = m(z, t, ctx, num_iterations=2)
    assert v.shape == z.shape
    assert not torch.isnan(v).any()
    v.sum().backward()
    # One SGD step, then confirm gradients still flow on a fresh pass.
    opt = torch.optim.SGD(m.parameters(), lr=0.01)
    opt.step()
    opt.zero_grad(set_to_none=True)
    m(z, t, ctx, 2).sum().backward()
    # tiny_decoder params are excluded from the core-gradient check.
    core_p = [(n, p) for n, p in m.named_parameters()
              if p.requires_grad and "tiny_decoder" not in n]
    assert all(p.grad is not None and p.grad.norm() > 1e-10 for _, p in core_p)
    print(f" Forward OK, all {len(core_p)} core params have grad, total={m.count_params()['total']:,}")
def test_flow_matching_loss():
    """Flow-matching loss is positive, finite, and differentiable."""
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss
    model = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4,
                 num_heads=4, max_iterations=4, gradient_checkpointing=False)
    l = flow_matching_loss(model, torch.randn(4, 32, 16, 16) * 2.5,
                           torch.randn(4, 8, 128), num_iterations=2)
    assert l["loss"].requires_grad
    assert not torch.isnan(l["loss"])
    assert l["loss"].item() > 0
    l["loss"].backward()
    print(f" flow_loss={l['flow_loss'].item():.4f}")
def test_euler_sampling():
    """Euler ODE sampling yields a finite latent that decodes to 512x512."""
    from iris.model import IRIS
    from iris.flow_matching import euler_sample
    model = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4,
                 num_heads=4, max_iterations=4, gradient_checkpointing=False)
    model.eval()
    with torch.no_grad():
        z = euler_sample(model, torch.randn(1, 32, 16, 16),
                         torch.randn(1, 8, 128), num_steps=5, num_iterations=2)
        assert z.shape == (1, 32, 16, 16)
        assert not torch.isnan(z).any()
        img = model.decode_latent(z)
        assert img.shape == (1, 3, 512, 512)
    print(f" Euler sampling OK, decoded {img.shape}")
def test_gradient_checkpointing():
    """Checkpointed and plain models must agree on loss and gradients."""
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss
    shared = dict(latent_channels=32, dim=64, patch_size=4, num_blocks=3,
                  num_heads=4, max_iterations=4)
    m1 = IRIS(gradient_checkpointing=False, **shared)
    m2 = IRIS(gradient_checkpointing=True, **shared)
    m2.load_state_dict(m1.state_dict())
    torch.manual_seed(42)
    z = torch.randn(2, 32, 16, 16) * 2.5
    ctx = torch.randn(2, 4, 64)
    # Reseed before each pass so both models see identical RNG draws.
    torch.manual_seed(123)
    l1 = flow_matching_loss(m1, z, ctx, num_iterations=3)
    l1["loss"].backward()
    torch.manual_seed(123)
    l2 = flow_matching_loss(m2, z, ctx, num_iterations=3)
    l2["loss"].backward()
    diff = abs(l1["loss"].item() - l2["loss"].item())
    maxg = max(
        (p1.grad - p2.grad).abs().max().item()
        for (n1, p1), (n2, p2) in zip(m1.named_parameters(), m2.named_parameters())
        if p1.grad is not None and p2.grad is not None
    )
    assert diff < 1e-6
    assert maxg < 1e-4
    print(f" Checkpointing: loss diff={diff:.8f}, max grad diff={maxg:.8f}")
def test_weight_sharing():
    """More refinement iterations change the output; param count is constant."""
    from iris.core import RefinementCore
    c = RefinementCore(dim=64, num_blocks=3, num_heads=4, spatial_size=4,
                       max_iterations=4, gradient_checkpointing=False)
    x = torch.randn(1, 16, 64)
    ctx = torch.randn(1, 4, 64)
    t = torch.tensor([0.5])
    o2 = c(x, ctx, t, 4, 4, num_iterations=2)
    o4 = c(x, ctx, t, 4, 4, num_iterations=4)
    # Extra iterations reuse the same weights yet must refine further.
    assert not torch.allclose(o2, o4, atol=1e-3)
    print(f" Weight sharing: {sum(p.numel() for p in c.parameters()):,} params (constant)")
def test_zero_init():
    """Zero-initialized output projection keeps the very first output tiny."""
    from iris.model import IRIS
    m = IRIS(latent_channels=32, dim=64, patch_size=4, num_blocks=3,
             num_heads=4, gradient_checkpointing=False)
    assert (m.unpatchify.proj.weight == 0).all()
    assert (m.unpatchify.proj.bias == 0).all()
    with torch.no_grad():
        out = m(torch.randn(1, 32, 16, 16), torch.tensor([0.5]),
                torch.randn(1, 4, 64), 1)
        assert out.norm().item() < 1.0
    print(f" Zero-init: output norm={out.norm().item():.6f}")
def test_training_stability():
    """100 AdamW steps on a fixed batch: loss must decrease and stay finite."""
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss
    m = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4,
             num_heads=4, max_iterations=4, gradient_checkpointing=False)
    opt = torch.optim.AdamW(m.parameters(), lr=1e-3, weight_decay=0.01)
    torch.manual_seed(0)
    z = torch.randn(8, 32, 16, 16) * 2.5
    ctx = torch.randn(8, 8, 128)
    losses = []
    m.train()
    for s in range(100):
        l = flow_matching_loss(m, z, ctx, num_iterations=2)
        opt.zero_grad(set_to_none=True)
        l["loss"].backward()
        # Clip to guard against a single exploding step.
        torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0)
        opt.step()
        losses.append(l["loss"].item())
        if (s + 1) % 25 == 0:
            print(f" Step {s+1}: loss={losses[-1]:.4f}")
    # Compare mean loss over the first and last ten steps.
    f10 = sum(losses[:10]) / 10
    l10 = sum(losses[-10:]) / 10
    print(f" Loss: {f10:.4f} -> {l10:.4f} ({(1-l10/f10)*100:.1f}%)")
    assert l10 < f10
    assert not any(math.isnan(l) or math.isinf(l) for l in losses)
if __name__ == "__main__":
tests = [
("SpectralConv2d", test_spectral_conv), ("TokenDifferential", test_token_differential),
("PDESSMBlock", test_pde_ssm_block), ("CrossAttention (MQA)", test_cross_attention),
("SelfAttention (MQA+2D RoPE)", test_self_attention), ("2D RoPE", test_rope_2d),
("UIB-FFN", test_uib_ffn), ("TimestepEmbedding", test_timestep_embedding),
("Patchify/Unpatchify", test_patchify_unpatchify), ("TinyDecoder", test_tiny_decoder),
("IRIS Forward", test_iris_forward), ("Flow Matching Loss", test_flow_matching_loss),
("Euler Sampling", test_euler_sampling), ("Gradient Checkpointing", test_gradient_checkpointing),
("Weight Sharing", test_weight_sharing), ("Zero-Init Output", test_zero_init),
("Training Stability (100 steps)", test_training_stability),
]
passed = sum(1 for n,f in tests if test_module(n,f))
print(f"\n{'='*60}\nRESULTS: {passed}/{len(tests)} passed\n{'='*60}")
if passed < len(tests): sys.exit(1)