| """ |
| Comprehensive test suite for LiRA architecture. |
| Tests: model creation, forward pass, memory footprint, gradient flow, |
| training step, and inference sampling. |
| """ |
|
|
| import torch |
| import sys |
| import os |
| sys.path.insert(0, '/app') |
|
|
| from lira.model import LiRAModel, LiRAPipeline, TinyVAEDecoder, estimate_memory_mb |
| from lira.training import ( |
| FlowMatchingScheduler, EMAModel, compute_loss, |
| LiRATrainingConfig, FlowDPMSolver |
| ) |
|
|
|
|
def test_model_creation():
    """Test all model configurations can be instantiated"""
    banner = "=" * 60
    print(banner)
    print("TEST 1: Model Creation & Parameter Counts")
    print(banner)

    # Instantiate each preset and report its parameter/memory budget.
    for cfg in ('tiny', 'small', 'base'):
        model = LiRAModel(
            config_name=cfg,
            in_channels=4,
            d_text=768,
            patch_size=2,
        )

        param_counts = model.count_parameters()
        denoiser_total = param_counts['total']

        print(f"\nLiRA-{cfg.capitalize()}:")
        print(f"  Total parameters: {denoiser_total / 1e6:.1f}M")
        for part, n in param_counts.items():
            if part == 'total':
                continue
            print(f"  {part}: {n/1e6:.2f}M ({n/denoiser_total*100:.1f}%)")

        # f8 VAE latent layout: 1024px image -> 128x128 latent, 4 channels.
        mem = estimate_memory_mb(model, batch_size=1, img_size=1024,
                                 spatial_compression=8, latent_channels=4, dtype_bytes=2)
        print(f"  Estimated inference memory (fp16): {mem['total_inference_mb']:.0f}MB")
        print(f"  Params: {mem['params_mb']:.0f}MB, Latent: {mem['latent_mb']:.1f}MB, Activations: {mem['activation_mb']:.1f}MB")

    # Deep-compression autoencoder variant: 32x spatial, 32 latent channels.
    print(f"\n--- f32 VAE Configuration (DC-AE) ---")
    f32_model = LiRAModel(
        config_name='small',
        in_channels=32,
        d_text=768,
        patch_size=1,
    )
    f32_counts = f32_model.count_parameters()
    f32_mem = estimate_memory_mb(f32_model, batch_size=1, img_size=1024,
                                 spatial_compression=32, latent_channels=32, dtype_bytes=2)
    print(f"  LiRA-Small (f32 VAE): {f32_counts['total']/1e6:.1f}M params")
    print(f"  Estimated inference memory (fp16): {f32_mem['total_inference_mb']:.0f}MB")
    print(f"  Latent tokens: {(1024//32)**2} (32x32)")

    print("\n✅ All model configurations created successfully!")
    return True
|
|
|
|
def test_forward_pass():
    """Test forward pass with proper shapes"""
    print("\n" + "=" * 60)
    print("TEST 2: Forward Pass")
    print("=" * 60)

    model = LiRAModel(
        config_name='tiny',
        in_channels=4,
        d_text=768,
        patch_size=2,
    )
    model.eval()

    batch = 2

    # Random stand-ins for a noised latent, timestep, and text conditioning.
    noisy_latent = torch.randn(batch, 4, 32, 32)
    timestep = torch.rand(batch)
    cond = torch.randn(batch, 77, 768)
    cond_mask = torch.ones(batch, 77, dtype=torch.bool)

    print(f"Input shapes:")
    print(f"  z_t: {noisy_latent.shape}")
    print(f"  t: {timestep.shape}")
    print(f"  text_features: {cond.shape}")

    with torch.no_grad():
        velocity, reason_info = model(noisy_latent, timestep, cond, cond_mask)

    print(f"\nOutput shapes:")
    print(f"  v_pred: {velocity.shape}")
    print(f"  Reasoning steps: {reason_info['total_steps']}")
    print(f"  Discard rates: {[f'{r:.3f}' for r in reason_info['discard_rates']]}")
    print(f"  Stop values: {[f'{s:.3f}' for s in reason_info['stop_values']]}")

    # The velocity prediction must mirror the input latent shape.
    assert velocity.shape == noisy_latent.shape, \
        f"Output shape mismatch: {velocity.shape} vs {noisy_latent.shape}"
    print("\n✅ Forward pass successful!")
    return True
|
|
|
|
def test_training_step():
    """Test a complete training step with loss computation.

    Runs three AdamW steps on random latents/text features and checks the
    resulting losses are finite and reasonably bounded.
    """
    print("\n" + "=" * 60)
    print("TEST 3: Training Step")
    print("=" * 60)

    config = LiRATrainingConfig(
        model_config='tiny',
        latent_channels=4,
        spatial_compression=8,
        d_text=768,
        patch_size=2,
        batch_size=2,
        learning_rate=1e-4,
    )

    model = LiRAModel(
        config_name=config.model_config,
        in_channels=config.latent_channels,
        d_text=config.d_text,
        patch_size=config.patch_size,
    )
    model.train()

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    scheduler = FlowMatchingScheduler(schedule=config.noise_schedule)
    ema = EMAModel(model, decay=config.ema_decay)

    # Synthetic clean latents and text conditioning.
    B = 2
    z_0 = torch.randn(B, 4, 32, 32)
    text_features = torch.randn(B, 77, 768)

    print("Running 3 training steps...")
    losses = []
    for step in range(3):
        optimizer.zero_grad()

        loss, info = compute_loss(
            model, z_0, text_features, scheduler, config,
            global_step=step
        )

        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()
        ema.update(model)  # keep the EMA shadow weights in sync

        losses.append(info['loss'])
        print(f"  Step {step}: loss={info['loss']:.4f}, "
              f"mse={info['mse_loss']:.4f}, "
              f"reason_steps={info['reason_steps']}, "
              f"grad_norm={grad_norm:.4f}")

    # Vectorized sanity checks: one tensor instead of the original
    # per-element torch.tensor(l) construction with ambiguous name `l`.
    loss_hist = torch.tensor(losses)
    assert torch.isfinite(loss_hist).all(), "Loss is not finite!"
    assert (loss_hist < 100).all(), "Loss is unreasonably large!"

    print("\n✅ Training step successful!")
    return True
|
|
|
|
def test_gradient_flow():
    """Verify gradients flow through all components"""
    print("\n" + "=" * 60)
    print("TEST 4: Gradient Flow Analysis")
    print("=" * 60)

    model = LiRAModel(
        config_name='tiny',
        in_channels=4,
        d_text=768,
        patch_size=2,
    )
    model.train()

    # One backward pass from a scalar objective seeds .grad everywhere.
    latent = torch.randn(1, 4, 32, 32)
    timestep = torch.rand(1)
    cond = torch.randn(1, 77, 768)

    prediction, _ = model(latent, timestep, cond)
    prediction.sum().backward()

    # Spot-check representative submodules for non-zero gradients.
    components = {
        'patch_embed': model.patch_embed,
        'time_embed': model.time_embed,
        'text_proj': model.text_proj,
        'reasoning': model.reasoning,
        'blocks[0]': model.blocks[0],
        'blocks[-1]': model.blocks[-1],
    }

    for name, module in components.items():
        params = list(module.parameters())
        has_grad = any(
            p.grad is not None and p.grad.abs().sum() > 0
            for p in params if p.requires_grad
        )
        grad_norm = sum(p.grad.norm().item() for p in params if p.grad is not None)
        marker = "✅" if has_grad else "❌"
        print(f"  {marker} {name}: grad_norm={grad_norm:.6f}")

    print("\n✅ Gradient flow verified!")
    return True
|
|
|
|
def test_sampling():
    """Test inference sampling"""
    print("\n" + "=" * 60)
    print("TEST 5: Inference Sampling")
    print("=" * 60)

    model = LiRAModel(
        config_name='tiny',
        in_channels=4,
        d_text=768,
        patch_size=2,
    )
    model.eval()

    solver = FlowDPMSolver(num_steps=5, order=2)
    prompt_emb = torch.randn(1, 77, 768)

    print("Sampling with DPM-Solver (5 steps)...")
    sample = solver.sample(
        model,
        shape=(1, 4, 32, 32),
        text_features=prompt_emb,
        cfg_scale=1.0,
    )

    print(f"  Output shape: {sample.shape}")
    print(f"  Output range: [{sample.min():.3f}, {sample.max():.3f}]")
    print(f"  Output std: {sample.std():.3f}")

    # Result must match the requested latent shape and be numerically sane.
    expected_shape = (1, 4, 32, 32)
    assert sample.shape == expected_shape, f"Wrong output shape: {sample.shape}"
    assert torch.isfinite(sample).all(), "Output contains NaN/Inf!"

    print("\n✅ Sampling successful!")
    return True
|
|
|
|
def test_tiny_decoder():
    """Test the mobile-optimized VAE decoder"""
    print("\n" + "=" * 60)
    print("TEST 6: Tiny VAE Decoder")
    print("=" * 60)

    def _exercise(label, channels, compression, latent_hw):
        # Build one decoder variant, push a random latent through it,
        # and report its size and input/output shapes.
        decoder = TinyVAEDecoder(
            in_channels=channels, spatial_compression=compression, base_channels=64
        )
        n_params = sum(p.numel() for p in decoder.parameters())

        latent = torch.randn(1, channels, latent_hw, latent_hw)
        with torch.no_grad():
            image = decoder(latent)

        print(label)
        print(f"  Parameters: {n_params/1e6:.2f}M ({n_params * 2 / (1024**2):.1f}MB fp16)")
        print(f"  Input: {latent.shape} → Output: {image.shape}")

    # f8 VAE: 4 latent channels, 128x128 latent grid.
    _exercise("f8 Decoder:", channels=4, compression=8, latent_hw=128)
    # f32 VAE: 32 latent channels, 32x32 latent grid.
    _exercise("\nf32 Decoder:", channels=32, compression=32, latent_hw=32)

    print("\n✅ Tiny VAE Decoder test passed!")
    return True
|
|
|
|
def test_noise_schedules():
    """Test all noise schedule variants"""
    print("\n" + "=" * 60)
    print("TEST 7: Noise Schedules")
    print("=" * 60)

    # Draw a large sample from each schedule and summarize its distribution.
    for name in ('laplace', 'logit_normal', 'uniform'):
        sched = FlowMatchingScheduler(schedule=name)
        samples = sched.sample_timesteps(10000, torch.device('cpu'))

        print(f"\n{name}:")
        print(f"  Mean: {samples.mean():.3f}, Std: {samples.std():.3f}")
        print(f"  Min: {samples.min():.3f}, Max: {samples.max():.3f}")

        # Normalized 10-bin histogram over [0, 1].
        hist = torch.histc(samples, bins=10, min=0, max=1)
        hist = hist / hist.sum()
        print(f"  Distribution (10 bins): {[f'{b:.2f}' for b in hist.tolist()]}")

    print("\n✅ All noise schedules working!")
    return True
|
|
|
|
def test_full_pipeline():
    """Test the complete pipeline including parameter summary"""
    print("\n" + "=" * 60)
    print("TEST 8: Full Pipeline Summary")
    print("=" * 60)

    pipeline = LiRAPipeline(
        config_name='small',
        latent_channels=32,
        spatial_compression=32,
        d_text=768,
        patch_size=1,
    )

    counts = pipeline.count_parameters()
    denoiser_total = counts['total']

    print("\n🏗️ LiRA-Small Pipeline (f32 VAE, 1024px native):")
    print(f"  Denoiser: {denoiser_total/1e6:.1f}M params")
    print(f"  Tiny Decoder: {counts['tiny_decoder']/1e6:.2f}M params")
    print(f"  Total: {counts['total_with_decoder']/1e6:.1f}M params")
    print(f"  Model size (fp16): {counts['total_with_decoder'] * 2 / (1024**2):.0f}MB")

    # Per-component share of the denoiser (aggregate keys are skipped).
    aggregate_keys = ('total', 'total_with_decoder', 'tiny_decoder')
    print(f"\n  Component breakdown:")
    for key, value in counts.items():
        if key in aggregate_keys:
            continue
        print(f"    {key}: {value/1e6:.2f}M ({value/denoiser_total*100:.1f}%)")

    mem = estimate_memory_mb(pipeline, batch_size=1, img_size=1024,
                             spatial_compression=32, latent_channels=32, dtype_bytes=2)
    print(f"\n  💾 Estimated inference memory:")
    print(f"    Model params: {mem['params_mb']:.0f}MB")
    print(f"    Latent tensors: {mem['latent_mb']:.1f}MB")
    print(f"    Activations: {mem['activation_mb']:.1f}MB")
    print(f"    Total: {mem['total_inference_mb']:.0f}MB")

    # 1024px image through a 32x-compression VAE -> 32x32 latent grid.
    side = 1024 // 32
    n_tokens = side * side
    print(f"\n  📐 Latent space:")
    print(f"    Image: 1024x1024px → Latent: {side}x{side} = {n_tokens} tokens")
    print(f"    Complexity: O({n_tokens}) per block (linear, not quadratic)")
    print(f"    Equivalent quadratic cost: O({n_tokens}²) = O({n_tokens**2:,})")

    print("\n✅ Full pipeline test passed!")
    return True
|
|
|
|
if __name__ == '__main__':
    print("🎨 LiRA (Liquid Reasoning Artisan) - Architecture Tests")
    print("=" * 60)

    # Tests run in order; each returns True on success or raises on error.
    tests = [
        test_model_creation,
        test_forward_pass,
        test_training_step,
        test_gradient_flow,
        test_sampling,
        test_tiny_decoder,
        test_noise_schedules,
        test_full_pipeline,
    ]

    passed = 0
    failed = 0

    for test_fn in tests:
        try:
            result = test_fn()
            if result:
                passed += 1
            else:
                failed += 1
        except Exception as e:
            print(f"\n❌ {test_fn.__name__} FAILED: {e}")
            import traceback
            traceback.print_exc()
            failed += 1

    print("\n" + "=" * 60)
    print(f"RESULTS: {passed} passed, {failed} failed out of {len(tests)} tests")
    print("=" * 60)

    # Propagate failure via the exit status so CI/scripts can detect it;
    # previously the script always exited 0 regardless of results.
    if failed:
        sys.exit(1)
|
|