asdf98 committed
Commit a02e7fd · verified · 1 parent: 1c7e629

Add test_lira.py

Files changed (1):
  test_lira.py +403 -0
test_lira.py ADDED
"""
Comprehensive test suite for the LiRA architecture.
Tests: model creation, forward pass, memory footprint, gradient flow,
training step, and inference sampling.
"""

import sys

import torch

sys.path.insert(0, '/app')

from lira.model import LiRAModel, LiRAPipeline, TinyVAEDecoder, estimate_memory_mb
from lira.training import (
    FlowMatchingScheduler, EMAModel, compute_loss,
    LiRATrainingConfig, FlowDPMSolver
)

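# Note: the tests below run on CPU with randomly initialized weights; they
# check shapes, wiring, and numerical sanity (finite values), not output
# quality.
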
def test_model_creation():
    """Test all model configurations can be instantiated"""
    print("=" * 60)
    print("TEST 1: Model Creation & Parameter Counts")
    print("=" * 60)

    configs = ['tiny', 'small', 'base']

    for config_name in configs:
        # Use SD1.x-style VAE params for testing (4ch, f8)
        model = LiRAModel(
            config_name=config_name,
            in_channels=4,
            d_text=768,
            patch_size=2,
        )

        counts = model.count_parameters()
        total_m = counts['total'] / 1e6

        print(f"\nLiRA-{config_name.capitalize()}:")
        print(f"  Total parameters: {total_m:.1f}M")
        for k, v in counts.items():
            if k != 'total':
                print(f"    {k}: {v/1e6:.2f}M ({v/counts['total']*100:.1f}%)")

        # Memory estimate for 1024px with f8 VAE
        mem = estimate_memory_mb(model, batch_size=1, img_size=1024,
                                 spatial_compression=8, latent_channels=4, dtype_bytes=2)
        print(f"  Estimated inference memory (fp16): {mem['total_inference_mb']:.0f}MB")
        print(f"  Params: {mem['params_mb']:.0f}MB, Latent: {mem['latent_mb']:.1f}MB, Activations: {mem['activation_mb']:.1f}MB")

    # Also test f32 VAE configuration
    print("\n--- f32 VAE Configuration (DC-AE) ---")
    model_f32 = LiRAModel(
        config_name='small',
        in_channels=32,
        d_text=768,
        patch_size=1,
    )
    counts_f32 = model_f32.count_parameters()
    mem_f32 = estimate_memory_mb(model_f32, batch_size=1, img_size=1024,
                                 spatial_compression=32, latent_channels=32, dtype_bytes=2)
    print(f"  LiRA-Small (f32 VAE): {counts_f32['total']/1e6:.1f}M params")
    print(f"  Estimated inference memory (fp16): {mem_f32['total_inference_mb']:.0f}MB")
    print(f"  Latent tokens: {(1024//32)**2} (32x32)")

    print("\n✅ All model configurations created successfully!")
    return True

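# fp16 sizing convention used throughout this file: bytes = n_params * 2,
# so MB = n_params * 2 / 1024**2 (e.g. a 100M-parameter model is ~190.7MB).
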
def test_forward_pass():
    """Test forward pass with proper shapes"""
    print("\n" + "=" * 60)
    print("TEST 2: Forward Pass")
    print("=" * 60)

    model = LiRAModel(
        config_name='tiny',
        in_channels=4,
        d_text=768,
        patch_size=2,
    )
    model.eval()

    # Simulate inputs
    B = 2

    # For 256px image with f8 VAE: 32x32 latent
    z_t = torch.randn(B, 4, 32, 32)
    t = torch.rand(B)
    text_features = torch.randn(B, 77, 768)  # CLIP-like
    text_mask = torch.ones(B, 77, dtype=torch.bool)

    print("Input shapes:")
    print(f"  z_t: {z_t.shape}")
    print(f"  t: {t.shape}")
    print(f"  text_features: {text_features.shape}")

    with torch.no_grad():
        v_pred, reason_info = model(z_t, t, text_features, text_mask)

    print("\nOutput shapes:")
    print(f"  v_pred: {v_pred.shape}")
    print(f"  Reasoning steps: {reason_info['total_steps']}")
    print(f"  Discard rates: {[f'{r:.3f}' for r in reason_info['discard_rates']]}")
    print(f"  Stop values: {[f'{s:.3f}' for s in reason_info['stop_values']]}")

    assert v_pred.shape == z_t.shape, f"Output shape mismatch: {v_pred.shape} vs {z_t.shape}"
    print("\n✅ Forward pass successful!")
    return True

def test_training_step():
    """Test a complete training step with loss computation"""
    print("\n" + "=" * 60)
    print("TEST 3: Training Step")
    print("=" * 60)

    config = LiRATrainingConfig(
        model_config='tiny',
        latent_channels=4,
        spatial_compression=8,
        d_text=768,
        patch_size=2,
        batch_size=2,
        learning_rate=1e-4,
    )

    model = LiRAModel(
        config_name=config.model_config,
        in_channels=config.latent_channels,
        d_text=config.d_text,
        patch_size=config.patch_size,
    )
    model.train()

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    scheduler = FlowMatchingScheduler(schedule=config.noise_schedule)
    ema = EMAModel(model, decay=config.ema_decay)

    # Simulate data
    B = 2
    z_0 = torch.randn(B, 4, 32, 32)  # Latent from VAE
    text_features = torch.randn(B, 77, 768)

    # Training loop (3 steps)
    print("Running 3 training steps...")
    losses = []
    for step in range(3):
        optimizer.zero_grad()

        loss, info = compute_loss(
            model, z_0, text_features, scheduler, config,
            global_step=step
        )

        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()
        ema.update(model)

        losses.append(info['loss'])
        print(f"  Step {step}: loss={info['loss']:.4f}, "
              f"mse={info['mse_loss']:.4f}, "
              f"reason_steps={info['reason_steps']}, "
              f"grad_norm={grad_norm:.4f}")

    # Verify loss is finite and reasonable
    assert all(torch.isfinite(torch.tensor(l)) for l in losses), "Loss is not finite!"
    assert all(l < 100 for l in losses), "Loss is unreasonably large!"

    print("\n✅ Training step successful!")
    return True

def test_gradient_flow():
    """Verify gradients flow through all components"""
    print("\n" + "=" * 60)
    print("TEST 4: Gradient Flow Analysis")
    print("=" * 60)

    model = LiRAModel(
        config_name='tiny',
        in_channels=4,
        d_text=768,
        patch_size=2,
    )
    model.train()

    z_t = torch.randn(1, 4, 32, 32)
    t = torch.rand(1)
    text = torch.randn(1, 77, 768)

    v_pred, _ = model(z_t, t, text)
    # Sum to a scalar surrogate loss so backward() reaches every parameter
    loss = v_pred.sum()
    loss.backward()

    # Check gradients in each component
    components = {
        'patch_embed': model.patch_embed,
        'time_embed': model.time_embed,
        'text_proj': model.text_proj,
        'reasoning': model.reasoning,
        'blocks[0]': model.blocks[0],
        'blocks[-1]': model.blocks[-1],
    }

    for name, module in components.items():
        has_grad = any(p.grad is not None and p.grad.abs().sum() > 0
                       for p in module.parameters() if p.requires_grad)
        grad_norm = sum(p.grad.norm().item() for p in module.parameters()
                        if p.grad is not None)
        status = "✅" if has_grad else "❌"
        print(f"  {status} {name}: grad_norm={grad_norm:.6f}")

    print("\n✅ Gradient flow verified!")
    return True

def test_sampling():
    """Test inference sampling"""
    print("\n" + "=" * 60)
    print("TEST 5: Inference Sampling")
    print("=" * 60)

    model = LiRAModel(
        config_name='tiny',
        in_channels=4,
        d_text=768,
        patch_size=2,
    )
    model.eval()

    solver = FlowDPMSolver(num_steps=5, order=2)  # Few steps for testing

    text_features = torch.randn(1, 77, 768)

    print("Sampling with DPM-Solver (5 steps)...")
    z_0 = solver.sample(
        model,
        shape=(1, 4, 32, 32),
        text_features=text_features,
        cfg_scale=1.0,  # No CFG for speed
    )

    print(f"  Output shape: {z_0.shape}")
    print(f"  Output range: [{z_0.min():.3f}, {z_0.max():.3f}]")
    print(f"  Output std: {z_0.std():.3f}")

    assert z_0.shape == (1, 4, 32, 32), f"Wrong output shape: {z_0.shape}"
    assert torch.isfinite(z_0).all(), "Output contains NaN/Inf!"

    print("\n✅ Sampling successful!")
    return True

def test_tiny_decoder():
    """Test the mobile-optimized VAE decoder"""
    print("\n" + "=" * 60)
    print("TEST 6: Tiny VAE Decoder")
    print("=" * 60)

    # Test f8 decoder (128x128 → 1024x1024)
    decoder_f8 = TinyVAEDecoder(
        in_channels=4, spatial_compression=8, base_channels=64
    )
    params_f8 = sum(p.numel() for p in decoder_f8.parameters())

    z = torch.randn(1, 4, 128, 128)
    with torch.no_grad():
        img = decoder_f8(z)

    print("f8 Decoder:")
    print(f"  Parameters: {params_f8/1e6:.2f}M ({params_f8 * 2 / (1024**2):.1f}MB fp16)")
    print(f"  Input: {z.shape} → Output: {img.shape}")

    # Test f32 decoder (32x32 → 1024x1024)
    decoder_f32 = TinyVAEDecoder(
        in_channels=32, spatial_compression=32, base_channels=64
    )
    params_f32 = sum(p.numel() for p in decoder_f32.parameters())

    z32 = torch.randn(1, 32, 32, 32)
    with torch.no_grad():
        img32 = decoder_f32(z32)

    print("\nf32 Decoder:")
    print(f"  Parameters: {params_f32/1e6:.2f}M ({params_f32 * 2 / (1024**2):.1f}MB fp16)")
    print(f"  Input: {z32.shape} → Output: {img32.shape}")

    print("\n✅ Tiny VAE Decoder test passed!")
    return True

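# Sketch (not executed): the solver and decoder above would compose into a
# sample-then-decode path; the exact wiring is an assumption, not part of
# these tests:
#   z = FlowDPMSolver(num_steps=5, order=2).sample(
#       model, shape=(1, 4, 32, 32), text_features=text_features)
#   img = decoder_f8(z)  # an f8 decoder upsamples 32x32 latents to 256x256
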
def test_noise_schedules():
    """Test all noise schedule variants"""
    print("\n" + "=" * 60)
    print("TEST 7: Noise Schedules")
    print("=" * 60)

    for schedule in ['laplace', 'logit_normal', 'uniform']:
        scheduler = FlowMatchingScheduler(schedule=schedule)
        t = scheduler.sample_timesteps(10000, torch.device('cpu'))

        print(f"\n{schedule}:")
        print(f"  Mean: {t.mean():.3f}, Std: {t.std():.3f}")
        print(f"  Min: {t.min():.3f}, Max: {t.max():.3f}")

        # Check distribution shape
        bins = torch.histc(t, bins=10, min=0, max=1)
        bins = bins / bins.sum()
        print(f"  Distribution (10 bins): {[f'{b:.2f}' for b in bins.tolist()]}")

    print("\n✅ All noise schedules working!")
    return True

def test_full_pipeline():
    """Test the complete pipeline including parameter summary"""
    print("\n" + "=" * 60)
    print("TEST 8: Full Pipeline Summary")
    print("=" * 60)

    pipeline = LiRAPipeline(
        config_name='small',
        latent_channels=32,
        spatial_compression=32,
        d_text=768,
        patch_size=1,
    )

    counts = pipeline.count_parameters()

    print("\n🏗️ LiRA-Small Pipeline (f32 VAE, 1024px native):")
    print(f"  Denoiser: {counts['total']/1e6:.1f}M params")
    print(f"  Tiny Decoder: {counts['tiny_decoder']/1e6:.2f}M params")
    print(f"  Total: {counts['total_with_decoder']/1e6:.1f}M params")
    print(f"  Model size (fp16): {counts['total_with_decoder'] * 2 / (1024**2):.0f}MB")

    # Breakdown
    print("\n  Component breakdown:")
    for k, v in counts.items():
        if k not in ['total', 'total_with_decoder', 'tiny_decoder']:
            print(f"    {k}: {v/1e6:.2f}M ({v/counts['total']*100:.1f}%)")

    # Memory estimate
    mem = estimate_memory_mb(pipeline, batch_size=1, img_size=1024,
                             spatial_compression=32, latent_channels=32, dtype_bytes=2)
    print("\n  💾 Estimated inference memory:")
    print(f"    Model params: {mem['params_mb']:.0f}MB")
    print(f"    Latent tensors: {mem['latent_mb']:.1f}MB")
    print(f"    Activations: {mem['activation_mb']:.1f}MB")
    print(f"    Total: {mem['total_inference_mb']:.0f}MB")

    # Latent token analysis
    lat_h = 1024 // 32
    lat_w = 1024 // 32
    print("\n  📐 Latent space:")
    print(f"    Image: 1024x1024px → Latent: {lat_h}x{lat_w} = {lat_h*lat_w} tokens")
    print(f"    Complexity: O({lat_h*lat_w}) per block (linear, not quadratic)")
    print(f"    Equivalent quadratic cost: O({lat_h*lat_w}²) = O({(lat_h*lat_w)**2:,})")

    print("\n✅ Full pipeline test passed!")
    return True

if __name__ == '__main__':
    print("🎨 LiRA (Liquid Reasoning Artisan) - Architecture Tests")
    print("=" * 60)

    tests = [
        test_model_creation,
        test_forward_pass,
        test_training_step,
        test_gradient_flow,
        test_sampling,
        test_tiny_decoder,
        test_noise_schedules,
        test_full_pipeline,
    ]

    passed = 0
    failed = 0

    for test_fn in tests:
        try:
            result = test_fn()
            if result:
                passed += 1
            else:
                failed += 1
        except Exception as e:
            print(f"\n❌ {test_fn.__name__} FAILED: {e}")
            import traceback
            traceback.print_exc()
            failed += 1

    print("\n" + "=" * 60)
    print(f"RESULTS: {passed} passed, {failed} failed out of {len(tests)} tests")
    print("=" * 60)