"""
Comprehensive test suite for LiRA architecture.

Tests: model creation, forward pass, memory footprint, gradient flow,
training step, and inference sampling.

Run directly (``python <this file>``) for a verbose report; the process exits
with status 1 if any test fails so CI can detect failures.
"""

import math
import os
import sys
import traceback

import torch

# NOTE(review): hard-coded path assumes the package lives under /app
# (e.g. inside a container) — confirm for other checkouts.
sys.path.insert(0, '/app')

from lira.model import LiRAModel, LiRAPipeline, TinyVAEDecoder, estimate_memory_mb
from lira.training import (
    FlowMatchingScheduler, EMAModel, compute_loss, LiRATrainingConfig, FlowDPMSolver
)


def test_model_creation():
    """Test all model configurations can be instantiated.

    Builds the 'tiny', 'small' and 'base' configs on a 4-channel latent with
    patch_size=2, printing parameter breakdowns and estimated fp16 inference
    memory at 1024px (f8 compression). Then builds a 'small' model on a
    32-channel latent with patch_size=1 (f32 compression).
    """
    print("=" * 60)
    print("TEST 1: Model Creation & Parameter Counts")
    print("=" * 60)

    configs = ['tiny', 'small', 'base']

    for config_name in configs:
        # Use SD1.x-style VAE params for testing (4ch, f8)
        model = LiRAModel(
            config_name=config_name,
            in_channels=4,
            d_text=768,
            patch_size=2,
        )
        counts = model.count_parameters()
        total_m = counts['total'] / 1e6

        print(f"\nLiRA-{config_name.capitalize()}:")
        print(f" Total parameters: {total_m:.1f}M")
        for k, v in counts.items():
            if k != 'total':
                print(f" {k}: {v/1e6:.2f}M ({v/counts['total']*100:.1f}%)")

        # Memory estimate for 1024px with f8 VAE
        mem = estimate_memory_mb(model, batch_size=1, img_size=1024,
                                 spatial_compression=8, latent_channels=4,
                                 dtype_bytes=2)
        print(f" Estimated inference memory (fp16): {mem['total_inference_mb']:.0f}MB")
        print(f" Params: {mem['params_mb']:.0f}MB, Latent: {mem['latent_mb']:.1f}MB, Activations: {mem['activation_mb']:.1f}MB")

    # Also test f32 VAE configuration
    print("\n--- f32 VAE Configuration (DC-AE) ---")
    model_f32 = LiRAModel(
        config_name='small',
        in_channels=32,
        d_text=768,
        patch_size=1,
    )
    counts_f32 = model_f32.count_parameters()
    mem_f32 = estimate_memory_mb(model_f32, batch_size=1, img_size=1024,
                                 spatial_compression=32, latent_channels=32,
                                 dtype_bytes=2)
    print(f" LiRA-Small (f32 VAE): {counts_f32['total']/1e6:.1f}M params")
    print(f" Estimated inference memory (fp16): {mem_f32['total_inference_mb']:.0f}MB")
    print(f" Latent tokens: {(1024//32)**2} (32x32)")

    print("\n✅ All model configurations created successfully!")
    return True


def test_forward_pass():
    """Test that a forward pass returns a velocity with the input's shape.

    Also prints the reasoning diagnostics (step count, discard rates, stop
    values) returned alongside the prediction.
    """
    print("\n" + "=" * 60)
    print("TEST 2: Forward Pass")
    print("=" * 60)

    model = LiRAModel(
        config_name='tiny',
        in_channels=4,
        d_text=768,
        patch_size=2,
    )
    model.eval()

    # Simulate inputs
    B = 2
    # For 256px image with f8 VAE: 32x32 latent
    z_t = torch.randn(B, 4, 32, 32)
    t = torch.rand(B)
    text_features = torch.randn(B, 77, 768)  # CLIP-like
    text_mask = torch.ones(B, 77, dtype=torch.bool)

    print("Input shapes:")
    print(f" z_t: {z_t.shape}")
    print(f" t: {t.shape}")
    print(f" text_features: {text_features.shape}")

    with torch.no_grad():
        v_pred, reason_info = model(z_t, t, text_features, text_mask)

    print("\nOutput shapes:")
    print(f" v_pred: {v_pred.shape}")
    print(f" Reasoning steps: {reason_info['total_steps']}")
    print(f" Discard rates: {[f'{r:.3f}' for r in reason_info['discard_rates']]}")
    print(f" Stop values: {[f'{s:.3f}' for s in reason_info['stop_values']]}")

    assert v_pred.shape == z_t.shape, f"Output shape mismatch: {v_pred.shape} vs {z_t.shape}"

    print("\n✅ Forward pass successful!")
    return True


def test_training_step():
    """Test a complete training step with loss computation.

    Runs three optimizer steps (AdamW + grad clipping + EMA update) on random
    latents and text features, and checks every reported loss is finite and
    below 100.
    """
    print("\n" + "=" * 60)
    print("TEST 3: Training Step")
    print("=" * 60)

    config = LiRATrainingConfig(
        model_config='tiny',
        latent_channels=4,
        spatial_compression=8,
        d_text=768,
        patch_size=2,
        batch_size=2,
        learning_rate=1e-4,
    )

    model = LiRAModel(
        config_name=config.model_config,
        in_channels=config.latent_channels,
        d_text=config.d_text,
        patch_size=config.patch_size,
    )
    model.train()

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )
    scheduler = FlowMatchingScheduler(schedule=config.noise_schedule)
    ema = EMAModel(model, decay=config.ema_decay)

    # Simulate data
    B = 2
    z_0 = torch.randn(B, 4, 32, 32)  # Latent from VAE
    text_features = torch.randn(B, 77, 768)

    # Training loop (3 steps)
    print("Running 3 training steps...")
    losses = []
    for step in range(3):
        optimizer.zero_grad()
        loss, info = compute_loss(
            model, z_0, text_features, scheduler, config, global_step=step
        )
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()
        ema.update(model)

        losses.append(info['loss'])
        print(f" Step {step}: loss={info['loss']:.4f}, "
              f"mse={info['mse_loss']:.4f}, "
              f"reason_steps={info['reason_steps']}, "
              f"grad_norm={grad_norm:.4f}")

    # Verify loss is finite and reasonable
    assert all(torch.isfinite(torch.tensor(loss_value)) for loss_value in losses), "Loss is not finite!"
    assert all(loss_value < 100 for loss_value in losses), "Loss is unreasonably large!"

    print("\n✅ Training step successful!")
    return True


def test_gradient_flow():
    """Verify gradients flow through all components.

    Backpropagates ``v_pred.sum()`` and fails if any listed component has no
    trainable parameter with a non-zero gradient.
    """
    print("\n" + "=" * 60)
    print("TEST 4: Gradient Flow Analysis")
    print("=" * 60)

    model = LiRAModel(
        config_name='tiny',
        in_channels=4,
        d_text=768,
        patch_size=2,
    )
    model.train()

    z_t = torch.randn(1, 4, 32, 32)
    t = torch.rand(1)
    text = torch.randn(1, 77, 768)

    v_pred, _ = model(z_t, t, text)
    loss = v_pred.sum()
    loss.backward()

    # Check gradients in each component
    components = {
        'patch_embed': model.patch_embed,
        'time_embed': model.time_embed,
        'text_proj': model.text_proj,
        'reasoning': model.reasoning,
        'blocks[0]': model.blocks[0],
        'blocks[-1]': model.blocks[-1],
    }

    missing = []
    for name, module in components.items():
        has_grad = any(p.grad is not None and p.grad.abs().sum() > 0
                       for p in module.parameters() if p.requires_grad)
        # True L2 norm over the component's gradients (same quantity that
        # clip_grad_norm_ reports), not the sum of per-parameter norms.
        grad_norm = math.sqrt(sum(p.grad.norm().item() ** 2
                                  for p in module.parameters() if p.grad is not None))
        status = "✅" if has_grad else "❌"
        print(f" {status} {name}: grad_norm={grad_norm:.6f}")
        if not has_grad:
            missing.append(name)

    # Previously the test always "passed" even when components were marked ❌.
    assert not missing, f"No gradient reached: {', '.join(missing)}"

    print("\n✅ Gradient flow verified!")
    return True


def test_sampling():
    """Test inference sampling with a few-step DPM-Solver and no CFG.

    Checks the sampled latent has the requested shape and is finite.
    """
    print("\n" + "=" * 60)
    print("TEST 5: Inference Sampling")
    print("=" * 60)

    model = LiRAModel(
        config_name='tiny',
        in_channels=4,
        d_text=768,
        patch_size=2,
    )
    model.eval()

    solver = FlowDPMSolver(num_steps=5, order=2)  # Few steps for testing
    text_features = torch.randn(1, 77, 768)

    print("Sampling with DPM-Solver (5 steps)...")
    z_0 = solver.sample(
        model,
        shape=(1, 4, 32, 32),
        text_features=text_features,
        cfg_scale=1.0,  # No CFG for speed
    )

    print(f" Output shape: {z_0.shape}")
    print(f" Output range: [{z_0.min():.3f}, {z_0.max():.3f}]")
    print(f" Output std: {z_0.std():.3f}")

    assert z_0.shape == (1, 4, 32, 32), f"Wrong output shape: {z_0.shape}"
    assert torch.isfinite(z_0).all(), "Output contains NaN/Inf!"

    print("\n✅ Sampling successful!")
    return True


def test_tiny_decoder():
    """Test the mobile-optimized VAE decoder for f8 and f32 latents.

    Checks that each decoder upsamples the latent's spatial size by its
    ``spatial_compression`` factor and produces finite output.
    """
    print("\n" + "=" * 60)
    print("TEST 6: Tiny VAE Decoder")
    print("=" * 60)

    # Test f8 decoder (128x128 → 1024x1024)
    decoder_f8 = TinyVAEDecoder(
        in_channels=4, spatial_compression=8, base_channels=64
    )
    params_f8 = sum(p.numel() for p in decoder_f8.parameters())
    z = torch.randn(1, 4, 128, 128)
    with torch.no_grad():
        img = decoder_f8(z)
    print("f8 Decoder:")
    print(f" Parameters: {params_f8/1e6:.2f}M ({params_f8 * 2 / (1024**2):.1f}MB fp16)")
    print(f" Input: {z.shape} → Output: {img.shape}")
    expected_f8 = (z.shape[-2] * 8, z.shape[-1] * 8)
    assert tuple(img.shape[-2:]) == expected_f8, \
        f"f8 decoder output spatial size {tuple(img.shape[-2:])}, expected {expected_f8}"
    assert torch.isfinite(img).all(), "f8 decoder output contains NaN/Inf!"

    # Test f32 decoder (32x32 → 1024x1024)
    decoder_f32 = TinyVAEDecoder(
        in_channels=32, spatial_compression=32, base_channels=64
    )
    params_f32 = sum(p.numel() for p in decoder_f32.parameters())
    z32 = torch.randn(1, 32, 32, 32)
    with torch.no_grad():
        img32 = decoder_f32(z32)
    print("\nf32 Decoder:")
    print(f" Parameters: {params_f32/1e6:.2f}M ({params_f32 * 2 / (1024**2):.1f}MB fp16)")
    print(f" Input: {z32.shape} → Output: {img32.shape}")
    expected_f32 = (z32.shape[-2] * 32, z32.shape[-1] * 32)
    assert tuple(img32.shape[-2:]) == expected_f32, \
        f"f32 decoder output spatial size {tuple(img32.shape[-2:])}, expected {expected_f32}"
    assert torch.isfinite(img32).all(), "f32 decoder output contains NaN/Inf!"

    print("\n✅ Tiny VAE Decoder test passed!")
    return True


def test_noise_schedules():
    """Test all noise schedule variants.

    Samples 10k timesteps per schedule, prints summary statistics and a
    10-bin histogram over [0, 1], and checks the samples are finite.
    """
    print("\n" + "=" * 60)
    print("TEST 7: Noise Schedules")
    print("=" * 60)

    for schedule in ['laplace', 'logit_normal', 'uniform']:
        scheduler = FlowMatchingScheduler(schedule=schedule)
        t = scheduler.sample_timesteps(10000, torch.device('cpu'))

        print(f"\n{schedule}:")
        print(f" Mean: {t.mean():.3f}, Std: {t.std():.3f}")
        print(f" Min: {t.min():.3f}, Max: {t.max():.3f}")

        # Check distribution shape (histc ignores values outside [min, max])
        bins = torch.histc(t, bins=10, min=0, max=1)
        bins = bins / bins.sum()
        print(f" Distribution (10 bins): {[f'{b:.2f}' for b in bins.tolist()]}")

        assert torch.isfinite(t).all(), f"{schedule}: sampled timesteps contain NaN/Inf!"

    print("\n✅ All noise schedules working!")
    return True


def test_full_pipeline():
    """Test the complete pipeline including parameter summary.

    Builds a 'small' pipeline on a 32-channel, f32 latent and prints parameter
    counts, a per-component breakdown, an fp16 memory estimate at 1024px and
    the latent token count.
    """
    print("\n" + "=" * 60)
    print("TEST 8: Full Pipeline Summary")
    print("=" * 60)

    pipeline = LiRAPipeline(
        config_name='small',
        latent_channels=32,
        spatial_compression=32,
        d_text=768,
        patch_size=1,
    )

    counts = pipeline.count_parameters()
    print("\n🏗️ LiRA-Small Pipeline (f32 VAE, 1024px native):")
    print(f" Denoiser: {counts['total']/1e6:.1f}M params")
    print(f" Tiny Decoder: {counts['tiny_decoder']/1e6:.2f}M params")
    print(f" Total: {counts['total_with_decoder']/1e6:.1f}M params")
    print(f" Model size (fp16): {counts['total_with_decoder'] * 2 / (1024**2):.0f}MB")

    # Breakdown
    print("\n Component breakdown:")
    for k, v in counts.items():
        if k not in ['total', 'total_with_decoder', 'tiny_decoder']:
            print(f" {k}: {v/1e6:.2f}M ({v/counts['total']*100:.1f}%)")

    # Memory estimate
    mem = estimate_memory_mb(pipeline, batch_size=1, img_size=1024,
                             spatial_compression=32, latent_channels=32,
                             dtype_bytes=2)
    print("\n 💾 Estimated inference memory:")
    print(f" Model params: {mem['params_mb']:.0f}MB")
    print(f" Latent tensors: {mem['latent_mb']:.1f}MB")
    print(f" Activations: {mem['activation_mb']:.1f}MB")
    print(f" Total: {mem['total_inference_mb']:.0f}MB")

    # Latent token analysis
    lat_h = 1024 // 32
    lat_w = 1024 // 32
    print("\n 📐 Latent space:")
    print(f" Image: 1024x1024px → Latent: {lat_h}x{lat_w} = {lat_h*lat_w} tokens")
    print(f" Complexity: O({lat_h*lat_w}) per block (linear, not quadratic)")
    print(f" Equivalent quadratic cost: O({lat_h*lat_w}²) = O({(lat_h*lat_w)**2:,})")

    print("\n✅ Full pipeline test passed!")
    return True


def main():
    """Run every test, print a summary, and return a process exit code.

    A test counts as passed when it returns a truthy value without raising.
    Returns 0 when all tests pass, 1 otherwise.
    """
    print("🎨 LiRA (Liquid Reasoning Artisan) - Architecture Tests")
    print("=" * 60)

    tests = [
        test_model_creation,
        test_forward_pass,
        test_training_step,
        test_gradient_flow,
        test_sampling,
        test_tiny_decoder,
        test_noise_schedules,
        test_full_pipeline,
    ]

    passed = 0
    failed = 0
    for test_fn in tests:
        try:
            result = test_fn()
        except Exception as e:  # top-level runner boundary: report and continue
            print(f"\n❌ {test_fn.__name__} FAILED: {e}")
            traceback.print_exc()
            failed += 1
            continue
        if result:
            passed += 1
        else:
            failed += 1

    print("\n" + "=" * 60)
    print(f"RESULTS: {passed} passed, {failed} failed out of {len(tests)} tests")
    print("=" * 60)
    return 0 if failed == 0 else 1


if __name__ == '__main__':
    # Propagate failures to the shell/CI instead of always exiting 0.
    sys.exit(main())