asdf98 committed on
Commit
fe73fcc
·
verified ·
1 Parent(s): 5a3f8af

Upload iris/test_all.py

Browse files
Files changed (1) hide show
  1. iris/test_all.py +221 -0
iris/test_all.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Comprehensive test suite for IRIS. 17 tests, all verified passing."""
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import time, sys, traceback, math
6
+
7
+ def test_module(name, fn):
8
+ print(f"\n{'='*60}\nTEST: {name}\n{'='*60}")
9
+ try:
10
+ fn()
11
+ print(f" PASSED")
12
+ return True
13
+ except Exception as e:
14
+ print(f" FAILED: {e}")
15
+ traceback.print_exc()
16
+ return False
17
+
18
def test_spectral_conv():
    """SpectralConv2d keeps shape, backprops into its weights, and stays NaN-free."""
    from iris.pde_ssm import SpectralConv2d
    for height, width in [(4, 4), (8, 8), (16, 16)]:
        conv = SpectralConv2d(channels=32, modes_h=height // 2, modes_w=width // 2)
        inp = torch.randn(2, 32, height, width)
        result = conv(inp)
        assert result.shape == inp.shape
        result.sum().backward()
        assert conv.weight_pos.grad is not None and conv.weight_pos.grad.norm() > 0
        assert not torch.isnan(result).any()
        print(f" SpectralConv2d({height}x{width}): OK")
29
+
30
def test_token_differential():
    """TokenDifferential is the identity at init and deviates once alpha is nonzero."""
    from iris.pde_ssm import TokenDifferential
    layer = TokenDifferential(32)
    inp = torch.randn(2, 32, 4, 4)
    assert torch.allclose(layer(inp), inp, atol=1e-6)
    layer.alpha.data.fill_(1.0)
    assert not torch.allclose(layer(inp), inp)
    print(" TokenDiff: identity at init, non-identity with alpha=1")
38
+
39
def test_pde_ssm_block():
    """PDESSMBlock: shape-preserving forward, finite grads for every trainable param."""
    from iris.pde_ssm import PDESSMBlock
    for size in (4, 8):
        block = PDESSMBlock(dim=64, spatial_size=size)
        tokens = torch.randn(2, size * size, 64)
        result = block(tokens, size, size)
        assert result.shape == tokens.shape
        result.sum().backward()
        for _, param in block.named_parameters():
            if param.requires_grad:
                assert param.grad is not None and not torch.isnan(param.grad).any()
        print(f" PDESSMBlock(s={size}): OK")
51
+
52
def test_cross_attention():
    """Multi-query cross-attention: output shape, backward, and K-proj smaller than Q-proj."""
    from iris.blocks import MultiQueryCrossAttention
    attn = MultiQueryCrossAttention(dim=64, num_heads=4)
    queries = torch.randn(2, 16, 64)
    context = torch.randn(2, 32, 64)
    result = attn(queries, context)
    assert result.shape == (2, 16, 64)
    result.sum().backward()
    # MQA's shared key projection must be strictly smaller than the per-head query one.
    assert attn.k_proj.weight.numel() < attn.q_proj.weight.numel()
    print(f" CrossAttn MQA: OK, K/Q ratio = {attn.q_proj.weight.numel()//attn.k_proj.weight.numel()}x")
60
+
61
def test_self_attention():
    """Multi-query self-attention with 2D RoPE keeps token shape and backprops."""
    from iris.blocks import MultiQuerySelfAttention
    attn = MultiQuerySelfAttention(dim=64, num_heads=4)
    result = attn(torch.randn(2, 16, 64), 4, 4)
    assert result.shape == (2, 16, 64)
    result.sum().backward()
    print(" SelfAttn+2D RoPE: OK")
68
+
69
def test_rope_2d():
    """2D rotary embedding keeps shape and approximately preserves per-token norms."""
    from iris.blocks import RotaryEmbedding2D
    rope = RotaryEmbedding2D(dim=16)
    inp = torch.randn(2, 4, 16, 16)
    rotated = rope(inp, 4, 4)
    assert rotated.shape == inp.shape
    mean_in = inp.norm(dim=-1).mean()
    mean_out = rotated.norm(dim=-1).mean()
    # Rotations should be norm-preserving; 10% relative slack tolerates implementation detail.
    assert abs(mean_in - mean_out) / mean_in < 0.1
    print(" 2D RoPE: norm preserved")
77
+
78
def test_uib_ffn():
    """UIB-FFN: shape-preserving forward, backward, and expected depthwise group count."""
    from iris.blocks import UIBFFN
    ffn = UIBFFN(dim=64, expansion=2)
    result = ffn(torch.randn(2, 16, 64), 4, 4)
    assert result.shape == (2, 16, 64)
    result.sum().backward()
    # 64 * 2 = 128 groups — presumably one per expanded channel; confirm against UIBFFN.
    assert ffn.dw_conv.groups == 128
    print(" UIB-FFN: OK")
86
+
87
def test_timestep_embedding():
    """Timestep embedding maps a batch of scalars in [0, 1] to finite (B, dim) vectors."""
    from iris.blocks import TimestepEmbedding
    embed = TimestepEmbedding(dim=64)
    result = embed(torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0]))
    assert result.shape == (5, 64) and not torch.isnan(result).any()
    print(" TimestepEmbed: OK")
93
+
94
def test_patchify_unpatchify():
    """Patchify -> Unpatchify round-trips the latent shape for both patch sizes."""
    from iris.model import Patchify, Unpatchify
    for patch in (2, 4):
        dim = 128 if patch == 2 else 512
        patchify = Patchify(32, dim, patch)
        unpatchify = Unpatchify(32, dim, patch)
        latent = torch.randn(2, 32, 16, 16)
        tokens, H, W = patchify(latent)
        assert tokens.shape == (2, (16 // patch) ** 2, dim)
        assert unpatchify(tokens, H, W).shape == latent.shape
        print(f" Patchify(ps={patch}): OK")
104
+
105
def test_tiny_decoder():
    """TinyDecoder turns 16x16 latents into 512x512 RGB with under 2M parameters."""
    from iris.model import TinyDecoder
    decoder = TinyDecoder(32, 3)
    img = decoder(torch.randn(2, 32, 16, 16))
    assert img.shape == (2, 3, 512, 512)
    n = sum(p.numel() for p in decoder.parameters())
    assert n < 2_000_000
    print(f" TinyDecoder: {n:,} params, output {img.shape}")
113
+
114
def test_iris_forward():
    """Forward pass, backward, optimizer step, and a second backward on IRIS.

    The second backward (after an SGD step + zero_grad) verifies the graph can
    be rebuilt and that gradients flow to every core parameter post-update.
    """
    from iris.model import IRIS
    m = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4, max_iterations=4, gradient_checkpointing=False)
    z = torch.randn(2, 32, 16, 16)  # latent input
    t = torch.tensor([0.3, 0.7])    # per-sample timesteps
    ctx = torch.randn(2, 8, 128)    # conditioning tokens
    v = m(z, t, ctx, num_iterations=2)
    assert v.shape == z.shape and not torch.isnan(v).any()
    v.sum().backward()
    opt = torch.optim.SGD(m.parameters(), lr=0.01)
    opt.step(); opt.zero_grad(set_to_none=True)
    # Second forward/backward after the weight update; positional num_iterations=2.
    m(z, t, ctx, 2).sum().backward()
    # Exclude "tiny_decoder" params — presumably not on this forward path; confirm in IRIS.
    core_p = [(n,p) for n,p in m.named_parameters() if p.requires_grad and "tiny_decoder" not in n]
    assert all(p.grad is not None and p.grad.norm()>1e-10 for _,p in core_p)
    print(f" Forward OK, all {len(core_p)} core params have grad, total={m.count_params()['total']:,}")
129
+
130
def test_flow_matching_loss():
    """Flow-matching loss is a positive, finite, differentiable scalar."""
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss
    model = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4, max_iterations=4, gradient_checkpointing=False)
    latents = torch.randn(4, 32, 16, 16) * 2.5
    context = torch.randn(4, 8, 128)
    result = flow_matching_loss(model, latents, context, num_iterations=2)
    loss = result["loss"]
    assert loss.requires_grad and not torch.isnan(loss) and loss.item() > 0
    loss.backward()
    print(f" flow_loss={result['flow_loss'].item():.4f}")
138
+
139
def test_euler_sampling():
    """Euler sampler yields a clean latent that decodes to a 512x512 image."""
    from iris.model import IRIS
    from iris.flow_matching import euler_sample
    model = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4, max_iterations=4, gradient_checkpointing=False)
    model.eval()
    with torch.no_grad():
        latent = euler_sample(model, torch.randn(1, 32, 16, 16), torch.randn(1, 8, 128), num_steps=5, num_iterations=2)
        assert latent.shape == (1, 32, 16, 16) and not torch.isnan(latent).any()
        img = model.decode_latent(latent)
        assert img.shape == (1, 3, 512, 512)
        print(f" Euler sampling OK, decoded {img.shape}")
150
+
151
def test_gradient_checkpointing():
    """Checkpointed and non-checkpointed models must match in loss and gradients.

    The two models share weights (load_state_dict) and are re-seeded identically
    before each loss call, so any divergence is attributable to checkpointing.
    """
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss
    m1 = IRIS(latent_channels=32, dim=64, patch_size=4, num_blocks=3, num_heads=4, max_iterations=4, gradient_checkpointing=False)
    m2 = IRIS(latent_channels=32, dim=64, patch_size=4, num_blocks=3, num_heads=4, max_iterations=4, gradient_checkpointing=True)
    m2.load_state_dict(m1.state_dict())  # identical weights in both models
    torch.manual_seed(42)
    z, ctx = torch.randn(2,32,16,16)*2.5, torch.randn(2,4,64)
    # Re-seed before EACH loss call so both draw the same internal randomness
    # (presumably flow_matching_loss samples noise/timesteps — confirm).
    torch.manual_seed(123)
    l1 = flow_matching_loss(m1, z, ctx, num_iterations=3); l1["loss"].backward()
    torch.manual_seed(123)
    l2 = flow_matching_loss(m2, z, ctx, num_iterations=3); l2["loss"].backward()
    diff = abs(l1["loss"].item() - l2["loss"].item())
    # Worst-case elementwise gradient discrepancy across all parameter pairs.
    maxg = max((p1.grad-p2.grad).abs().max().item() for (n1,p1),(n2,p2) in zip(m1.named_parameters(),m2.named_parameters()) if p1.grad is not None and p2.grad is not None)
    assert diff < 1e-6 and maxg < 1e-4
    print(f" Checkpointing: loss diff={diff:.8f}, max grad diff={maxg:.8f}")
167
+
168
def test_weight_sharing():
    """RefinementCore: extra iterations change the output while params stay constant."""
    from iris.core import RefinementCore
    core = RefinementCore(dim=64, num_blocks=3, num_heads=4, spatial_size=4, max_iterations=4, gradient_checkpointing=False)
    tokens = torch.randn(1, 16, 64)
    ctx = torch.randn(1, 4, 64)
    t = torch.tensor([0.5])
    out_two = core(tokens, ctx, t, 4, 4, num_iterations=2)
    out_four = core(tokens, ctx, t, 4, 4, num_iterations=4)
    # Running more refinement iterations must actually move the output.
    assert not torch.allclose(out_two, out_four, atol=1e-3)
    print(f" Weight sharing: {sum(p.numel() for p in core.parameters()):,} params (constant)")
176
+
177
def test_zero_init():
    """Zero-initialized output projection keeps the untrained model's output tiny."""
    from iris.model import IRIS
    model = IRIS(latent_channels=32, dim=64, patch_size=4, num_blocks=3, num_heads=4, gradient_checkpointing=False)
    proj = model.unpatchify.proj
    assert (proj.weight == 0).all() and (proj.bias == 0).all()
    with torch.no_grad():
        out = model(torch.randn(1, 32, 16, 16), torch.tensor([0.5]), torch.randn(1, 4, 64), 1)
        assert out.norm().item() < 1.0
    print(f" Zero-init: output norm={out.norm().item():.6f}")
185
+
186
def test_training_stability():
    """100 AdamW steps on one fixed batch: loss must decrease and never go NaN/inf."""
    from iris.model import IRIS
    from iris.flow_matching import flow_matching_loss
    m = IRIS(latent_channels=32, dim=128, patch_size=4, num_blocks=4, num_heads=4, max_iterations=4, gradient_checkpointing=False)
    opt = torch.optim.AdamW(m.parameters(), lr=1e-3, weight_decay=0.01)
    torch.manual_seed(0)  # fixed batch so repeated steps amount to an overfitting check
    z, ctx = torch.randn(8,32,16,16)*2.5, torch.randn(8,8,128)
    losses = []
    m.train()
    for s in range(100):
        l = flow_matching_loss(m, z, ctx, num_iterations=2)
        opt.zero_grad(set_to_none=True)
        l["loss"].backward()
        torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0)  # guard against grad spikes
        opt.step()
        losses.append(l["loss"].item())
        if (s+1) % 25 == 0: print(f" Step {s+1}: loss={losses[-1]:.4f}")
    # Compare mean loss over the first vs last 10 steps.
    f10, l10 = sum(losses[:10])/10, sum(losses[-10:])/10
    print(f" Loss: {f10:.4f} -> {l10:.4f} ({(1-l10/f10)*100:.1f}%)")
    assert l10 < f10 and not any(math.isnan(l) or math.isinf(l) for l in losses)
206
+
207
if __name__ == "__main__":
    # Registry of (label, callable); each runs through the test_module harness.
    tests = [
        ("SpectralConv2d", test_spectral_conv), ("TokenDifferential", test_token_differential),
        ("PDESSMBlock", test_pde_ssm_block), ("CrossAttention (MQA)", test_cross_attention),
        ("SelfAttention (MQA+2D RoPE)", test_self_attention), ("2D RoPE", test_rope_2d),
        ("UIB-FFN", test_uib_ffn), ("TimestepEmbedding", test_timestep_embedding),
        ("Patchify/Unpatchify", test_patchify_unpatchify), ("TinyDecoder", test_tiny_decoder),
        ("IRIS Forward", test_iris_forward), ("Flow Matching Loss", test_flow_matching_loss),
        ("Euler Sampling", test_euler_sampling), ("Gradient Checkpointing", test_gradient_checkpointing),
        ("Weight Sharing", test_weight_sharing), ("Zero-Init Output", test_zero_init),
        ("Training Stability (100 steps)", test_training_stability),
    ]
    passed = 0
    for label, fn in tests:
        if test_module(label, fn):
            passed += 1
    bar = "=" * 60
    print(f"\n{bar}\nRESULTS: {passed}/{len(tests)} passed\n{bar}")
    # Non-zero exit signals CI failure when any test failed.
    if passed < len(tests):
        sys.exit(1)