asdf98 commited on
Commit
9bfb518
·
verified ·
1 Parent(s): 2c6f96a

Add test_microforge.py

Browse files
Files changed (1) hide show
  1. test_microforge.py +254 -0
test_microforge.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ MicroForge End-to-End Test Suite
4
+ Validates all modules work correctly on CPU.
5
+ """
6
+
7
+ import torch
8
+ import time
9
+ import sys
10
+ import os
11
+
12
+ # Add parent to path
13
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
14
+
15
+
16
+ def test_vae():
17
+ """Test all VAE configurations."""
18
+ from microforge.vae import MicroForgeVAE
19
+
20
+ print("=" * 60)
21
+ print("TEST: MicroForge VAE")
22
+ print("=" * 60)
23
+
24
+ for config in ['tiny', 'small', 'base']:
25
+ vae = MicroForgeVAE(config=config)
26
+ params = sum(p.numel() for p in vae.parameters())
27
+
28
+ # Test forward pass
29
+ x = torch.randn(1, 3, 256, 256)
30
+ x_recon, mu, logvar = vae(x)
31
+
32
+ assert x_recon.shape == x.shape, f"Recon shape mismatch: {x_recon.shape} vs {x.shape}"
33
+ assert not torch.isnan(mu).any(), "NaN in mu"
34
+ assert not torch.isnan(logvar).any(), "NaN in logvar"
35
+
36
+ # Test encode/decode
37
+ z = vae.get_latent(x)
38
+ x_dec = vae.decode(z)
39
+ assert x_dec.shape == x.shape
40
+
41
+ # Test KL loss
42
+ kl = MicroForgeVAE.kl_loss(mu, logvar)
43
+ assert not torch.isnan(kl), "NaN in KL loss"
44
+
45
+ print(f" [{config:>5}] PASS | params={params:,} | latent={mu.shape} | KL={kl.item():.2f}")
46
+
47
+ print()
48
+
49
+
50
+ def test_backbone():
51
+ """Test all backbone configurations."""
52
+ from microforge.backbone import MicroForgeBackbone
53
+
54
+ print("=" * 60)
55
+ print("TEST: MicroForge Backbone")
56
+ print("=" * 60)
57
+
58
+ for config in ['tiny', 'small', 'base']:
59
+ lc = 16 if config == 'tiny' else 32
60
+ backbone = MicroForgeBackbone(latent_channels=lc, config=config)
61
+ params = sum(p.numel() for p in backbone.parameters())
62
+
63
+ z = torch.randn(1, lc, 8, 8)
64
+ t = torch.rand(1)
65
+ text_emb = torch.randn(1, 10, 768)
66
+ text_pooled = torch.randn(1, 768)
67
+
68
+ start = time.time()
69
+ v = backbone(z, t, text_emb, text_pooled)
70
+ elapsed = (time.time() - start) * 1000
71
+
72
+ assert v.shape == z.shape, f"Output shape mismatch: {v.shape} vs {z.shape}"
73
+ assert not torch.isnan(v).any(), "NaN in velocity prediction"
74
+
75
+ print(f" [{config:>5}] PASS | params={params:,} | latency={elapsed:.0f}ms")
76
+
77
+ print()
78
+
79
+
80
+ def test_planner():
81
+ """Test Recurrent Latent Planner."""
82
+ from microforge.planner import RecurrentLatentPlanner
83
+
84
+ print("=" * 60)
85
+ print("TEST: Recurrent Latent Planner")
86
+ print("=" * 60)
87
+
88
+ planner = RecurrentLatentPlanner(
89
+ num_plan_tokens=32, dim=384, text_dim=768, latent_channels=32
90
+ )
91
+ params = sum(p.numel() for p in planner.parameters())
92
+
93
+ # Test initialization
94
+ text_pooled = torch.randn(2, 768)
95
+ plan = planner.initialize_plan(text_pooled, batch_size=2)
96
+ assert plan.shape == (2, 32, 384), f"Plan shape: {plan.shape}"
97
+
98
+ # Test forward
99
+ img_tokens = torch.randn(2, 64, 32) # 8x8 latent flattened
100
+ t_emb = torch.randn(2, 384)
101
+ plan_out, output = planner(img_tokens, plan, t_emb)
102
+
103
+ assert plan_out.shape == (2, 32, 384)
104
+ assert output.shape == (2, 32, 768) # Projected to text_dim
105
+ assert not torch.isnan(plan_out).any()
106
+ assert not torch.isnan(output).any()
107
+
108
+ # Test self-conditioning
109
+ plan_next = planner.initialize_plan(text_pooled, 2, prev_plan=plan_out)
110
+ assert plan_next.shape == plan.shape
111
+
112
+ print(f" PASS | params={params:,} | plan_state={planner.get_plan_size_bytes()} bytes")
113
+ print()
114
+
115
+
116
+ def test_training():
117
+ """Test training loop."""
118
+ from microforge.vae import MicroForgeVAE
119
+ from microforge.backbone import MicroForgeBackbone
120
+ from microforge.planner import RecurrentLatentPlanner
121
+ from microforge.training import MicroForgeTrainer, FlowMatchingScheduler
122
+
123
+ print("=" * 60)
124
+ print("TEST: Training Pipeline")
125
+ print("=" * 60)
126
+
127
+ vae = MicroForgeVAE(config='tiny').eval()
128
+ backbone = MicroForgeBackbone(latent_channels=16, config='tiny')
129
+ planner = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)
130
+
131
+ trainer = MicroForgeTrainer(vae, backbone, planner, lr=1e-4, use_ema=True)
132
+
133
+ # Test flow matching scheduler
134
+ scheduler = FlowMatchingScheduler()
135
+ t = scheduler.sample_timesteps(4, torch.device('cpu'))
136
+ assert t.min() >= 0 and t.max() <= 1, f"Timesteps out of range: {t}"
137
+
138
+ z_0 = torch.randn(4, 16, 4, 4)
139
+ noise = torch.randn_like(z_0)
140
+ z_t, v_target = scheduler.add_noise(z_0, noise, t)
141
+ assert z_t.shape == z_0.shape
142
+ assert v_target.shape == z_0.shape
143
+
144
+ # Test training steps
145
+ images = torch.randn(2, 3, 128, 128)
146
+ text_emb = torch.randn(2, 10, 768)
147
+ text_pooled = torch.randn(2, 768)
148
+
149
+ losses = []
150
+ for i in range(5):
151
+ step_losses = trainer.train_step(images, text_emb, text_pooled)
152
+ losses.append(step_losses['flow'])
153
+ assert not any(torch.isnan(torch.tensor(v)) for v in step_losses.values()), \
154
+ f"NaN in losses: {step_losses}"
155
+
156
+ print(f" 5 training steps: loss {losses[0]:.2f} -> {losses[-1]:.2f}")
157
+ print(f" PASS")
158
+ print()
159
+
160
+
161
+ def test_pipeline():
162
+ """Test end-to-end inference pipeline."""
163
+ from microforge.vae import MicroForgeVAE
164
+ from microforge.backbone import MicroForgeBackbone
165
+ from microforge.planner import RecurrentLatentPlanner
166
+ from microforge.pipeline import MicroForgePipeline, SimpleTextEncoder
167
+
168
+ print("=" * 60)
169
+ print("TEST: End-to-End Pipeline")
170
+ print("=" * 60)
171
+
172
+ vae = MicroForgeVAE(config='tiny')
173
+ backbone = MicroForgeBackbone(latent_channels=16, config='tiny')
174
+ planner = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)
175
+ text_enc = SimpleTextEncoder(embed_dim=768, num_layers=2)
176
+
177
+ pipeline = MicroForgePipeline(vae, backbone, text_enc, planner, device='cpu')
178
+
179
+ # Test text2img
180
+ tokens = torch.randint(0, 8192, (1, 10))
181
+ start = time.time()
182
+ images = pipeline.text2img(tokens, height=128, width=128, num_steps=2, cfg_scale=1.0, seed=42)
183
+ t2i_time = time.time() - start
184
+
185
+ assert images.shape == (1, 3, 128, 128), f"Wrong output shape: {images.shape}"
186
+ assert images.min() >= -1 and images.max() <= 1, f"Range error: [{images.min()}, {images.max()}]"
187
+
188
+ print(f" text2img: {images.shape} in {t2i_time:.2f}s | PASS")
189
+
190
+ # Test parameter count
191
+ params = pipeline.count_parameters()
192
+ print(f" Total params: {params['total']:,}")
193
+
194
+ # Test memory estimate
195
+ mem = pipeline.get_memory_estimate(512, 512)
196
+ print(f" Est. memory @512px: {mem['estimated_inference_mb']:.0f} MB")
197
+
198
+ print(f" PASS")
199
+ print()
200
+
201
+
202
+ def test_editing_pathway():
203
+ """Test that editing pathway works (spatial concat)."""
204
+ from microforge.backbone import MicroForgeBackbone
205
+
206
+ print("=" * 60)
207
+ print("TEST: Editing Pathway (Spatial Concat)")
208
+ print("=" * 60)
209
+
210
+ backbone = MicroForgeBackbone(latent_channels=16, config='tiny')
211
+
212
+ # Standard generation: 8x8 latent
213
+ z_gen = torch.randn(1, 16, 8, 8)
214
+ t = torch.rand(1)
215
+ text_emb = torch.randn(1, 5, 768)
216
+ text_pooled = torch.randn(1, 768)
217
+
218
+ v_gen = backbone(z_gen, t, text_emb, text_pooled)
219
+ assert v_gen.shape == z_gen.shape, f"Gen output shape: {v_gen.shape}"
220
+
221
+ # Editing: 8x16 latent (width-concat target + source)
222
+ z_edit = torch.randn(1, 16, 8, 16) # Doubled width
223
+ v_edit = backbone(z_edit, t, text_emb, text_pooled)
224
+ assert v_edit.shape == z_edit.shape, f"Edit output shape: {v_edit.shape}"
225
+
226
+ # Extract target velocity (left half)
227
+ v_target = v_edit[..., :8]
228
+ assert v_target.shape == z_gen.shape
229
+
230
+ print(f" Generation: {z_gen.shape} -> {v_gen.shape} | PASS")
231
+ print(f" Editing: {z_edit.shape} -> {v_edit.shape} | PASS")
232
+ print()
233
+
234
+
235
+ def main():
236
+ print()
237
+ print("🔨 MicroForge Architecture Test Suite")
238
+ print("=" * 60)
239
+ print()
240
+
241
+ test_vae()
242
+ test_backbone()
243
+ test_planner()
244
+ test_training()
245
+ test_pipeline()
246
+ test_editing_pathway()
247
+
248
+ print("=" * 60)
249
+ print("✅ ALL TESTS PASSED")
250
+ print("=" * 60)
251
+
252
+
253
+ if __name__ == "__main__":
254
+ main()