Upload iris/test_all.py
Browse files- iris/test_all.py +221 -0
iris/test_all.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Comprehensive test suite for IRIS. 17 tests, all verified passing."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import time, sys, traceback, math
|
| 6 |
+
|
| 7 |
+
def test_module(name, fn):
|
| 8 |
+
print(f"\n{'='*60}\nTEST: {name}\n{'='*60}")
|
| 9 |
+
try:
|
| 10 |
+
fn()
|
| 11 |
+
print(f" PASSED")
|
| 12 |
+
return True
|
| 13 |
+
except Exception as e:
|
| 14 |
+
print(f" FAILED: {e}")
|
| 15 |
+
traceback.print_exc()
|
| 16 |
+
return False
|
| 17 |
+
|
| 18 |
+
def test_spectral_conv():
    """SpectralConv2d: shape preservation, nonzero gradients, finite output at several sizes."""
    from iris.pde_ssm import SpectralConv2d
    for side_h, side_w in [(4, 4), (8, 8), (16, 16)]:
        layer = SpectralConv2d(channels=32, modes_h=side_h // 2, modes_w=side_w // 2)
        sample = torch.randn(2, 32, side_h, side_w)
        result = layer(sample)
        assert result.shape == sample.shape
        result.sum().backward()
        grad = layer.weight_pos.grad
        assert grad is not None and grad.norm() > 0
        assert not torch.isnan(result).any()
        print(f" SpectralConv2d({side_h}x{side_w}): OK")
|
| 29 |
+
|
| 30 |
+
def test_token_differential():
    """TokenDifferential: exact identity at init, non-identity once alpha is raised."""
    from iris.pde_ssm import TokenDifferential
    layer = TokenDifferential(32)
    sample = torch.randn(2, 32, 4, 4)
    # Fresh module should pass input through unchanged.
    assert torch.allclose(layer(sample), sample, atol=1e-6)
    layer.alpha.data.fill_(1.0)
    # With alpha=1 the differential term must actually perturb the input.
    assert not torch.allclose(layer(sample), sample)
    print(" TokenDiff: identity at init, non-identity with alpha=1")
|
| 38 |
+
|
| 39 |
+
def test_pde_ssm_block():
    """PDESSMBlock: forward shape and finite gradients for every trainable parameter."""
    from iris.pde_ssm import PDESSMBlock
    for size in (4, 8):
        block = PDESSMBlock(dim=64, spatial_size=size)
        tokens = torch.randn(2, size * size, 64)
        result = block(tokens, size, size)
        assert result.shape == tokens.shape
        result.sum().backward()
        for _label, param in block.named_parameters():
            if param.requires_grad:
                assert param.grad is not None and not torch.isnan(param.grad).any()
        print(f" PDESSMBlock(s={size}): OK")
|
| 51 |
+
|
| 52 |
+
def test_cross_attention():
    """MultiQueryCrossAttention: output shape, backward pass, and shared-KV compression."""
    from iris.blocks import MultiQueryCrossAttention
    attn = MultiQueryCrossAttention(dim=64, num_heads=4)
    result = attn(torch.randn(2, 16, 64), torch.randn(2, 32, 64))
    assert result.shape == (2, 16, 64)
    result.sum().backward()
    # MQA contract: the shared key projection is smaller than the per-head query projection.
    assert attn.k_proj.weight.numel() < attn.q_proj.weight.numel()
    print(f" CrossAttn MQA: OK, K/Q ratio = {attn.q_proj.weight.numel()//attn.k_proj.weight.numel()}x")
|
| 60 |
+
|
| 61 |
+
def test_self_attention():
    """MultiQuerySelfAttention with 2D RoPE: output shape and backward pass."""
    from iris.blocks import MultiQuerySelfAttention
    attn = MultiQuerySelfAttention(dim=64, num_heads=4)
    result = attn(torch.randn(2, 16, 64), 4, 4)
    assert result.shape == (2, 16, 64)
    result.sum().backward()
    print(" SelfAttn+2D RoPE: OK")
|
| 68 |
+
|
| 69 |
+
def test_rope_2d():
    """RotaryEmbedding2D: shape-preserving and approximately norm-preserving."""
    from iris.blocks import RotaryEmbedding2D
    rotary = RotaryEmbedding2D(dim=16)
    sample = torch.randn(2, 4, 16, 16)
    rotated = rotary(sample, 4, 4)
    assert rotated.shape == sample.shape
    # Rotations should keep vector norms within 10% on average.
    base_norm = sample.norm(dim=-1).mean()
    assert abs(base_norm - rotated.norm(dim=-1).mean()) / base_norm < 0.1
    print(" 2D RoPE: norm preserved")
|
| 77 |
+
|
| 78 |
+
def test_uib_ffn():
    """UIBFFN: output shape, backward pass, and fully depthwise conv (groups == hidden width)."""
    from iris.blocks import UIBFFN
    ffn = UIBFFN(dim=64, expansion=2)
    result = ffn(torch.randn(2, 16, 64), 4, 4)
    assert result.shape == (2, 16, 64)
    result.sum().backward()
    # dim 64 * expansion 2 -> 128 groups, i.e. one group per channel (depthwise).
    assert ffn.dw_conv.groups == 128
    print(" UIB-FFN: OK")
|
| 86 |
+
|
| 87 |
+
def test_timestep_embedding():
    """TimestepEmbedding: finite (5, dim) output for timesteps spanning [0, 1]."""
    from iris.blocks import TimestepEmbedding
    embed = TimestepEmbedding(dim=64)
    result = embed(torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0]))
    assert result.shape == (5, 64) and not torch.isnan(result).any()
    print(" TimestepEmbed: OK")
|
| 93 |
+
|
| 94 |
+
def test_patchify_unpatchify():
    """Patchify/Unpatchify: token count and round-trip shape for patch sizes 2 and 4."""
    from iris.model import Patchify, Unpatchify
    for patch in (2, 4):
        width = 128 if patch == 2 else 512
        to_tokens = Patchify(32, width, patch)
        to_grid = Unpatchify(32, width, patch)
        latent = torch.randn(2, 32, 16, 16)
        tokens, rows, cols = to_tokens(latent)
        assert tokens.shape == (2, (16 // patch) ** 2, width)
        # Unpatchify must invert the spatial layout exactly.
        assert to_grid(tokens, rows, cols).shape == latent.shape
        print(f" Patchify(ps={patch}): OK")
|
| 104 |
+
|
| 105 |
+
def test_tiny_decoder():
    """TinyDecoder: 16x16 latent -> 512x512 RGB image within a 2M-parameter budget."""
    from iris.model import TinyDecoder
    decoder = TinyDecoder(32, 3)
    image = decoder(torch.randn(2, 32, 16, 16))
    assert image.shape == (2, 3, 512, 512)
    param_count = sum(p.numel() for p in decoder.parameters())
    assert param_count < 2_000_000
    print(f" TinyDecoder: {param_count:,} params, output {image.shape}")
|
| 113 |
+
|
| 114 |
+
def test_iris_forward():
    """IRIS forward: shape/NaN checks, an optimizer step, then grads on all core params."""
    from iris.model import IRIS
    model = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4, max_iterations=4, gradient_checkpointing=False)
    latent = torch.randn(2, 32, 16, 16)
    timesteps = torch.tensor([0.3, 0.7])
    context = torch.randn(2, 8, 128)
    velocity = model(latent, timesteps, context, num_iterations=2)
    assert velocity.shape == latent.shape and not torch.isnan(velocity).any()
    velocity.sum().backward()
    # Take one SGD step, then verify a second backward still reaches every
    # core parameter (the tiny decoder is excluded — it is not on this path).
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    model(latent, timesteps, context, 2).sum().backward()
    core = [(label, p) for label, p in model.named_parameters()
            if p.requires_grad and "tiny_decoder" not in label]
    assert all(p.grad is not None and p.grad.norm() > 1e-10 for _, p in core)
    print(f" Forward OK, all {len(core)} core params have grad, total={model.count_params()['total']:,}")
|
| 129 |
+
|
| 130 |
+
def test_flow_matching_loss():
    """flow_matching_loss: positive, differentiable, NaN-free scalar loss."""
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss
    model = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4, max_iterations=4, gradient_checkpointing=False)
    out = flow_matching_loss(model, torch.randn(4, 32, 16, 16) * 2.5, torch.randn(4, 8, 128), num_iterations=2)
    loss = out["loss"]
    assert loss.requires_grad and not torch.isnan(loss) and loss.item() > 0
    loss.backward()
    print(f" flow_loss={out['flow_loss'].item():.4f}")
|
| 138 |
+
|
| 139 |
+
def test_euler_sampling():
    """euler_sample: NaN-free latent of the right shape, decodable to a 512x512 image."""
    from iris.model import IRIS
    from iris.flow_matching import euler_sample
    model = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4, max_iterations=4, gradient_checkpointing=False)
    model.eval()
    with torch.no_grad():
        latent = euler_sample(model, torch.randn(1, 32, 16, 16), torch.randn(1, 8, 128), num_steps=5, num_iterations=2)
        assert latent.shape == (1, 32, 16, 16) and not torch.isnan(latent).any()
        decoded = model.decode_latent(latent)
        assert decoded.shape == (1, 3, 512, 512)
        print(f" Euler sampling OK, decoded {decoded.shape}")
|
| 150 |
+
|
| 151 |
+
def test_gradient_checkpointing():
    """Checkpointed and non-checkpointed models must agree on loss and gradients."""
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss
    # Two architecturally identical models; only the checkpointing flag differs.
    m1 = IRIS(latent_channels=32, dim=64, patch_size=4, num_blocks=3, num_heads=4, max_iterations=4, gradient_checkpointing=False)
    m2 = IRIS(latent_channels=32, dim=64, patch_size=4, num_blocks=3, num_heads=4, max_iterations=4, gradient_checkpointing=True)
    m2.load_state_dict(m1.state_dict())  # force exactly equal weights
    torch.manual_seed(42)
    z, ctx = torch.randn(2,32,16,16)*2.5, torch.randn(2,4,64)
    # Reseed before each loss so both models draw identical noise/timesteps
    # inside flow_matching_loss; the seed order here is load-bearing.
    torch.manual_seed(123)
    l1 = flow_matching_loss(m1, z, ctx, num_iterations=3); l1["loss"].backward()
    torch.manual_seed(123)
    l2 = flow_matching_loss(m2, z, ctx, num_iterations=3); l2["loss"].backward()
    diff = abs(l1["loss"].item() - l2["loss"].item())
    # Same architecture => named_parameters() yields pairs in matching order,
    # so a plain zip compares corresponding tensors.
    maxg = max((p1.grad-p2.grad).abs().max().item() for (n1,p1),(n2,p2) in zip(m1.named_parameters(),m2.named_parameters()) if p1.grad is not None and p2.grad is not None)
    assert diff < 1e-6 and maxg < 1e-4
    print(f" Checkpointing: loss diff={diff:.8f}, max grad diff={maxg:.8f}")
|
| 167 |
+
|
| 168 |
+
def test_weight_sharing():
    """RefinementCore: more iterations change the output while the parameter count is fixed."""
    from iris.core import RefinementCore
    core = RefinementCore(dim=64, num_blocks=3, num_heads=4, spatial_size=4, max_iterations=4, gradient_checkpointing=False)
    tokens = torch.randn(1, 16, 64)
    context = torch.randn(1, 4, 64)
    timestep = torch.tensor([0.5])
    two_iters = core(tokens, context, timestep, 4, 4, num_iterations=2)
    four_iters = core(tokens, context, timestep, 4, 4, num_iterations=4)
    # Extra iterations must actually refine, not no-op.
    assert not torch.allclose(two_iters, four_iters, atol=1e-3)
    print(f" Weight sharing: {sum(p.numel() for p in core.parameters()):,} params (constant)")
|
| 176 |
+
|
| 177 |
+
def test_zero_init():
    """Zero-initialized output projection keeps the first forward pass near zero."""
    from iris.model import IRIS
    model = IRIS(latent_channels=32, dim=64, patch_size=4, num_blocks=3, num_heads=4, gradient_checkpointing=False)
    proj = model.unpatchify.proj
    assert (proj.weight == 0).all() and (proj.bias == 0).all()
    with torch.no_grad():
        result = model(torch.randn(1, 32, 16, 16), torch.tensor([0.5]), torch.randn(1, 4, 64), 1)
        assert result.norm().item() < 1.0
        print(f" Zero-init: output norm={result.norm().item():.6f}")
|
| 185 |
+
|
| 186 |
+
def test_training_stability():
    """100 AdamW steps on one fixed batch: loss must fall with no NaN/Inf anywhere."""
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss
    model = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4, max_iterations=4, gradient_checkpointing=False)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    # Fixed seed so the batch (and hence the loss curve) is reproducible.
    torch.manual_seed(0)
    latent = torch.randn(8, 32, 16, 16) * 2.5
    context = torch.randn(8, 8, 128)
    history = []
    model.train()
    for step in range(100):
        out = flow_matching_loss(model, latent, context, num_iterations=2)
        optimizer.zero_grad(set_to_none=True)
        out["loss"].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        history.append(out["loss"].item())
        if (step + 1) % 25 == 0:
            print(f" Step {step+1}: loss={history[-1]:.4f}")
    # Compare the mean of the first and last 10 steps rather than single points.
    first_avg = sum(history[:10]) / 10
    last_avg = sum(history[-10:]) / 10
    print(f" Loss: {first_avg:.4f} -> {last_avg:.4f} ({(1-last_avg/first_avg)*100:.1f}%)")
    assert last_avg < first_avg and not any(math.isnan(l) or math.isinf(l) for l in history)
|
| 206 |
+
|
| 207 |
+
if __name__ == "__main__":
    # Ordered roughly bottom-up: primitive layers first, end-to-end training last.
    suite = [
        ("SpectralConv2d", test_spectral_conv),
        ("TokenDifferential", test_token_differential),
        ("PDESSMBlock", test_pde_ssm_block),
        ("CrossAttention (MQA)", test_cross_attention),
        ("SelfAttention (MQA+2D RoPE)", test_self_attention),
        ("2D RoPE", test_rope_2d),
        ("UIB-FFN", test_uib_ffn),
        ("TimestepEmbedding", test_timestep_embedding),
        ("Patchify/Unpatchify", test_patchify_unpatchify),
        ("TinyDecoder", test_tiny_decoder),
        ("IRIS Forward", test_iris_forward),
        ("Flow Matching Loss", test_flow_matching_loss),
        ("Euler Sampling", test_euler_sampling),
        ("Gradient Checkpointing", test_gradient_checkpointing),
        ("Weight Sharing", test_weight_sharing),
        ("Zero-Init Output", test_zero_init),
        ("Training Stability (100 steps)", test_training_stability),
    ]
    passed = sum(1 for label, fn in suite if test_module(label, fn))
    print(f"\n{'='*60}\nRESULTS: {passed}/{len(suite)} passed\n{'='*60}")
    if passed < len(suite):
        sys.exit(1)
|