# LiRA / test_lira.py
# Author: asdf98
# Commit: a02e7fd ("Add test_lira.py", verified)
"""
Comprehensive test suite for LiRA architecture.
Tests: model creation and memory footprint, forward pass, training step,
gradient flow, inference sampling, the tiny VAE decoder, noise schedules,
and the full-pipeline parameter/memory summary.
"""
import torch
import sys
import os
sys.path.insert(0, '/app')
from lira.model import LiRAModel, LiRAPipeline, TinyVAEDecoder, estimate_memory_mb
from lira.training import (
FlowMatchingScheduler, EMAModel, compute_loss,
LiRATrainingConfig, FlowDPMSolver
)
def test_model_creation():
    """Instantiate every LiRA size preset and report its footprint.

    For each of the ``tiny``/``small``/``base`` presets, a denoiser is built
    for an SD1.x-style f8 latent (4 channels, patch size 2). Its per-component
    parameter counts are printed, and ``estimate_memory_mb`` is queried for a
    1024px fp16 inference.

    A ``small`` model for a 32-channel f32 (DC-AE style) latent with patch
    size 1 is then built and reported the same way.

    Returns:
        True once every configuration has been constructed without raising.
    """
    print("=" * 60)
    print("TEST 1: Model Creation & Parameter Counts")
    print("=" * 60)

    for preset in ('tiny', 'small', 'base'):
        # SD1.x-style VAE latent for testing: 4 channels, f8 compression.
        net = LiRAModel(config_name=preset, in_channels=4, d_text=768, patch_size=2)
        tallies = net.count_parameters()
        grand_total = tallies['total']

        print(f"\nLiRA-{preset.capitalize()}:")
        print(f" Total parameters: {grand_total / 1e6:.1f}M")
        for part, n_params in tallies.items():
            if part == 'total':
                continue
            print(f" {part}: {n_params/1e6:.2f}M ({n_params/grand_total*100:.1f}%)")

        # fp16 inference-memory estimate for a 1024px image through an f8 VAE.
        mem = estimate_memory_mb(net, batch_size=1, img_size=1024,
                                 spatial_compression=8, latent_channels=4, dtype_bytes=2)
        print(f" Estimated inference memory (fp16): {mem['total_inference_mb']:.0f}MB")
        print(f" Params: {mem['params_mb']:.0f}MB, Latent: {mem['latent_mb']:.1f}MB, "
              f"Activations: {mem['activation_mb']:.1f}MB")

    # Same 'small' preset, but on a 32-channel f32 latent (DC-AE style).
    print("\n--- f32 VAE Configuration (DC-AE) ---")
    net_f32 = LiRAModel(config_name='small', in_channels=32, d_text=768, patch_size=1)
    tallies_f32 = net_f32.count_parameters()
    mem_f32 = estimate_memory_mb(net_f32, batch_size=1, img_size=1024,
                                 spatial_compression=32, latent_channels=32, dtype_bytes=2)
    latent_side = 1024 // 32
    print(f" LiRA-Small (f32 VAE): {tallies_f32['total']/1e6:.1f}M params")
    print(f" Estimated inference memory (fp16): {mem_f32['total_inference_mb']:.0f}MB")
    print(f" Latent tokens: {latent_side ** 2} (32x32)")

    print("\n✅ All model configurations created successfully!")
    return True
def test_forward_pass():
    """Run one inference-mode forward pass and check the output shape.

    A ``tiny`` model (4-channel latent, patch size 2) receives:
      - a batch of two random 32x32 latents (the f8 latent of a 256px image),
      - uniform random timesteps,
      - 77-token, 768-dim CLIP-like text features with an all-true mask.

    The predicted velocity shape and the reasoning diagnostics returned by the
    model (step count, discard rates, stop values) are printed.

    Returns:
        True when ``v_pred`` has the same shape as the input latent.
    """
    print("\n" + "=" * 60)
    print("TEST 2: Forward Pass")
    print("=" * 60)

    net = LiRAModel(config_name='tiny', in_channels=4, d_text=768, patch_size=2)
    net.eval()

    batch = 2
    # 256px image through an f8 VAE -> 32x32 latent grid.
    latent = torch.randn(batch, 4, 32, 32)
    timesteps = torch.rand(batch)
    cond = torch.randn(batch, 77, 768)  # CLIP-like token features
    cond_mask = torch.ones(batch, 77, dtype=torch.bool)

    print("Input shapes:")
    for label, tensor in (("z_t", latent), ("t", timesteps), ("text_features", cond)):
        print(f" {label}: {tensor.shape}")

    with torch.no_grad():
        velocity, diagnostics = net(latent, timesteps, cond, cond_mask)

    print("\nOutput shapes:")
    print(f" v_pred: {velocity.shape}")
    print(f" Reasoning steps: {diagnostics['total_steps']}")
    discard = [f'{rate:.3f}' for rate in diagnostics['discard_rates']]
    print(f" Discard rates: {discard}")
    stops = [f'{value:.3f}' for value in diagnostics['stop_values']]
    print(f" Stop values: {stops}")

    assert velocity.shape == latent.shape, f"Output shape mismatch: {velocity.shape} vs {latent.shape}"
    print("\n✅ Forward pass successful!")
    return True
def test_training_step():
    """Run three optimizer steps of flow-matching training on random data.

    Setup:
      - a ``tiny`` model built from a ``LiRATrainingConfig``
        (4-channel f8 latent, batch size 2),
      - an AdamW optimizer,
      - a ``FlowMatchingScheduler`` and an ``EMAModel``.

    Each of the three iterations performs zero_grad, compute_loss, backward,
    gradient clipping, optimizer step and EMA update on one fixed random batch.

    Returns:
        True if every recorded loss is finite and below 100.
    """
    print("\n" + "=" * 60)
    print("TEST 3: Training Step")
    print("=" * 60)

    cfg = LiRATrainingConfig(
        model_config='tiny',
        latent_channels=4,
        spatial_compression=8,
        d_text=768,
        patch_size=2,
        batch_size=2,
        learning_rate=1e-4,
    )
    net = LiRAModel(
        config_name=cfg.model_config,
        in_channels=cfg.latent_channels,
        d_text=cfg.d_text,
        patch_size=cfg.patch_size,
    )
    net.train()

    opt = torch.optim.AdamW(net.parameters(), lr=cfg.learning_rate,
                            weight_decay=cfg.weight_decay)
    noise_sched = FlowMatchingScheduler(schedule=cfg.noise_schedule)
    ema = EMAModel(net, decay=cfg.ema_decay)

    batch = 2
    latents = torch.randn(batch, 4, 32, 32)  # stand-in for VAE-encoded images
    cond = torch.randn(batch, 77, 768)

    print("Running 3 training steps...")
    history = []
    for step in range(3):
        opt.zero_grad()
        loss, stats = compute_loss(net, latents, cond, noise_sched, cfg, global_step=step)
        loss.backward()
        total_norm = torch.nn.utils.clip_grad_norm_(net.parameters(), cfg.grad_clip)
        opt.step()
        ema.update(net)
        history.append(stats['loss'])
        print(
            f" Step {step}: loss={stats['loss']:.4f}, "
            f"mse={stats['mse_loss']:.4f}, "
            f"reason_steps={stats['reason_steps']}, "
            f"grad_norm={total_norm:.4f}"
        )

    # Losses must be finite and of a sane magnitude.
    assert all(torch.isfinite(torch.tensor(value)) for value in history), "Loss is not finite!"
    assert all(value < 100 for value in history), "Loss is unreasonably large!"
    print("\n✅ Training step successful!")
    return True
def test_gradient_flow():
    """Verify gradients flow through all components.

    Runs one forward/backward pass of a ``tiny`` LiRA model on random inputs,
    using ``v_pred.sum()`` as a surrogate loss. It then checks that each listed
    sub-module (patch embedding, time embedding, text projection, reasoning
    module, first and last block) received at least one non-zero gradient.
    Each component's gradient L2 norm is printed for inspection.

    Returns:
        True if every listed component received a non-zero gradient.

    Raises:
        AssertionError: if any listed component received no gradient, or only
            all-zero gradients.
    """
    print("\n" + "=" * 60)
    print("TEST 4: Gradient Flow Analysis")
    print("=" * 60)
    model = LiRAModel(
        config_name='tiny',
        in_channels=4,
        d_text=768,
        patch_size=2,
    )
    model.train()
    z_t = torch.randn(1, 4, 32, 32)
    t = torch.rand(1)
    text = torch.randn(1, 77, 768)
    v_pred, _ = model(z_t, t, text)
    # Surrogate scalar loss: depends on every output element, so any
    # parameter that influences v_pred receives a gradient.
    loss = v_pred.sum()
    loss.backward()

    # Check gradients in each component
    components = {
        'patch_embed': model.patch_embed,
        'time_embed': model.time_embed,
        'text_proj': model.text_proj,
        'reasoning': model.reasoning,
        'blocks[0]': model.blocks[0],
        'blocks[-1]': model.blocks[-1],
    }
    starved = []
    for name, module in components.items():
        grads = [p.grad.detach() for p in module.parameters()
                 if p.requires_grad and p.grad is not None]
        has_grad = any(g.abs().sum() > 0 for g in grads)
        # L2 norm of the component's full gradient vector. The norm of the
        # per-tensor norms equals that; a plain sum of norms would not.
        if grads:
            grad_norm = torch.stack([g.float().norm() for g in grads]).norm().item()
        else:
            grad_norm = 0.0
        status = "✅" if has_grad else "❌"
        print(f" {status} {name}: grad_norm={grad_norm:.6f}")
        if not has_grad:
            starved.append(name)

    # A printed ❌ alone would not fail the run; assert so a severed gradient
    # path is reported as a test failure.
    assert not starved, f"No gradient reached: {', '.join(starved)}"
    print("\n✅ Gradient flow verified!")
    return True
def test_sampling():
    """Draw one latent with the flow DPM-Solver and sanity-check it.

    Sampling uses a ``tiny`` model in eval mode and a 5-step, second-order
    ``FlowDPMSolver`` with ``cfg_scale=1.0``, conditioned on random 77x768
    text features. The sample's shape, value range and standard deviation are
    printed.

    Returns:
        True when the sample is (1, 4, 32, 32) and entirely finite.
    """
    print("\n" + "=" * 60)
    print("TEST 5: Inference Sampling")
    print("=" * 60)

    net = LiRAModel(config_name='tiny', in_channels=4, d_text=768, patch_size=2)
    net.eval()

    # Only a handful of steps: this checks the plumbing, not sample quality.
    sampler = FlowDPMSolver(num_steps=5, order=2)
    cond = torch.randn(1, 77, 768)
    target_shape = (1, 4, 32, 32)

    print("Sampling with DPM-Solver (5 steps)...")
    sample = sampler.sample(
        net,
        shape=target_shape,
        text_features=cond,
        cfg_scale=1.0,  # intended as "no CFG" to keep the test fast
    )

    print(f" Output shape: {sample.shape}")
    print(f" Output range: [{sample.min():.3f}, {sample.max():.3f}]")
    print(f" Output std: {sample.std():.3f}")

    assert sample.shape == target_shape, f"Wrong output shape: {sample.shape}"
    assert torch.isfinite(sample).all(), "Output contains NaN/Inf!"
    print("\n✅ Sampling successful!")
    return True
def test_tiny_decoder():
    """Test the mobile-optimized VAE decoder.

    Builds two decoders, both with ``base_channels=64``:
      - f8: decodes a 4-channel 128x128 latent,
      - f32: decodes a 32-channel 32x32 latent.

    Each decodes one random latent under ``torch.no_grad()``. Parameter counts,
    fp16 weight size and input/output shapes are printed. Both decoders are
    then checked to preserve the batch size, scale each spatial dimension by
    their compression factor (1024x1024 here) and produce finite values.

    Returns:
        True when both decoders pass the checks.

    Raises:
        AssertionError: on a wrong batch/spatial output size or NaN/Inf output.
    """
    print("\n" + "=" * 60)
    print("TEST 6: Tiny VAE Decoder")
    print("=" * 60)

    def _check_output(out, latent, compression, label):
        # Batch is preserved and each trailing spatial dim grows by
        # ``compression``. Channel count is not checked: this test does not
        # pin it down. Assumes channel-first (N, C, H, W) output, like the input.
        expected_hw = (latent.shape[-2] * compression, latent.shape[-1] * compression)
        assert out.shape[0] == latent.shape[0], (
            f"{label} decoder changed batch size: {tuple(latent.shape)} -> {tuple(out.shape)}")
        assert tuple(out.shape[-2:]) == expected_hw, (
            f"{label} decoder output {tuple(out.shape)}, expected spatial size {expected_hw}")
        assert torch.isfinite(out).all(), f"{label} decoder output contains NaN/Inf!"

    # Test f8 decoder (128x128 → 1024x1024)
    decoder_f8 = TinyVAEDecoder(
        in_channels=4, spatial_compression=8, base_channels=64
    )
    params_f8 = sum(p.numel() for p in decoder_f8.parameters())
    z = torch.randn(1, 4, 128, 128)
    with torch.no_grad():
        img = decoder_f8(z)
    print("f8 Decoder:")
    print(f" Parameters: {params_f8/1e6:.2f}M ({params_f8 * 2 / (1024**2):.1f}MB fp16)")
    print(f" Input: {z.shape} → Output: {img.shape}")
    _check_output(img, z, 8, "f8")

    # Test f32 decoder (32x32 → 1024x1024)
    decoder_f32 = TinyVAEDecoder(
        in_channels=32, spatial_compression=32, base_channels=64
    )
    params_f32 = sum(p.numel() for p in decoder_f32.parameters())
    z32 = torch.randn(1, 32, 32, 32)
    with torch.no_grad():
        img32 = decoder_f32(z32)
    print("\nf32 Decoder:")
    print(f" Parameters: {params_f32/1e6:.2f}M ({params_f32 * 2 / (1024**2):.1f}MB fp16)")
    print(f" Input: {z32.shape} → Output: {img32.shape}")
    _check_output(img32, z32, 32, "f32")

    print("\n✅ Tiny VAE Decoder test passed!")
    return True
def test_noise_schedules():
    """Test all noise schedule variants.

    For each of ``laplace``, ``logit_normal`` and ``uniform``, draws 10,000
    training timesteps on CPU from ``FlowMatchingScheduler.sample_timesteps``.
    It prints summary statistics and a normalized 10-bin histogram over [0, 1].

    Returns:
        True when every schedule yields the requested number of finite
        timesteps, all inside [0, 1].

    Raises:
        AssertionError: on a wrong sample count, NaN/Inf, or out-of-range values.
    """
    print("\n" + "=" * 60)
    print("TEST 7: Noise Schedules")
    print("=" * 60)
    num_samples = 10000
    for schedule in ['laplace', 'logit_normal', 'uniform']:
        scheduler = FlowMatchingScheduler(schedule=schedule)
        t = scheduler.sample_timesteps(num_samples, torch.device('cpu'))
        print(f"\n{schedule}:")
        print(f" Mean: {t.mean():.3f}, Std: {t.std():.3f}")
        print(f" Min: {t.min():.3f}, Max: {t.max():.3f}")

        # Validate before histogramming. torch.histc silently ignores values
        # outside [min, max], so out-of-range timesteps would vanish from the
        # histogram, and an all-out-of-range draw would normalize to NaN.
        assert t.numel() == num_samples, (
            f"{schedule}: expected {num_samples} timesteps, got {t.numel()}")
        assert torch.isfinite(t).all(), f"{schedule}: timesteps contain NaN/Inf!"
        assert ((t >= 0) & (t <= 1)).all(), (
            f"{schedule}: timesteps outside [0, 1] "
            f"(min={t.min():.3f}, max={t.max():.3f})")

        # Check distribution shape
        bins = torch.histc(t, bins=10, min=0, max=1)
        bins = bins / bins.sum()
        print(f" Distribution (10 bins): {[f'{b:.2f}' for b in bins.tolist()]}")
    print("\n✅ All noise schedules working!")
    return True
def test_full_pipeline():
    """Build the LiRA-Small pipeline and print its size and memory summary.

    The pipeline targets a 32-channel f32 latent with patch size 1. The
    report covers:
      - denoiser, tiny-decoder and combined parameter counts,
      - fp16 model size and a per-component breakdown,
      - the ``estimate_memory_mb`` figures for a 1024px image,
      - a latent-token count with linear vs quadratic cost figures.

    Returns:
        True once the summary has been printed without raising.
    """
    print("\n" + "=" * 60)
    print("TEST 8: Full Pipeline Summary")
    print("=" * 60)

    img_px = 1024
    factor = 32

    pipe = LiRAPipeline(
        config_name='small',
        latent_channels=32,
        spatial_compression=factor,
        d_text=768,
        patch_size=1,
    )
    tallies = pipe.count_parameters()
    denoiser_total = tallies['total']

    print("\n🏗️ LiRA-Small Pipeline (f32 VAE, 1024px native):")
    print(f" Denoiser: {denoiser_total/1e6:.1f}M params")
    print(f" Tiny Decoder: {tallies['tiny_decoder']/1e6:.2f}M params")
    print(f" Total: {tallies['total_with_decoder']/1e6:.1f}M params")
    print(f" Model size (fp16): {tallies['total_with_decoder'] * 2 / (1024**2):.0f}MB")

    # Per-component breakdown; skip the aggregate entries already shown above.
    aggregate_keys = ('total', 'total_with_decoder', 'tiny_decoder')
    print("\n Component breakdown:")
    for key, n_params in tallies.items():
        if key in aggregate_keys:
            continue
        print(f" {key}: {n_params/1e6:.2f}M ({n_params/denoiser_total*100:.1f}%)")

    mem = estimate_memory_mb(pipe, batch_size=1, img_size=img_px,
                             spatial_compression=factor, latent_channels=32, dtype_bytes=2)
    print("\n 💾 Estimated inference memory:")
    print(f" Model params: {mem['params_mb']:.0f}MB")
    print(f" Latent tensors: {mem['latent_mb']:.1f}MB")
    print(f" Activations: {mem['activation_mb']:.1f}MB")
    print(f" Total: {mem['total_inference_mb']:.0f}MB")

    # Token count of the square latent grid and the cost comparison.
    side = img_px // factor
    n_tokens = side * side
    print("\n 📐 Latent space:")
    print(f" Image: 1024x1024px → Latent: {side}x{side} = {n_tokens} tokens")
    print(f" Complexity: O({n_tokens}) per block (linear, not quadratic)")
    print(f" Equivalent quadratic cost: O({n_tokens}²) = O({n_tokens ** 2:,})")

    print("\n✅ Full pipeline test passed!")
    return True
if __name__ == '__main__':
    # Script entry point: run every test, report a summary, and exit non-zero
    # if any test raised or returned a falsy value.
    import traceback

    print("🎨 LiRA (Liquid Reasoning Artisan) - Architecture Tests")
    print("=" * 60)
    tests = [
        test_model_creation,
        test_forward_pass,
        test_training_step,
        test_gradient_flow,
        test_sampling,
        test_tiny_decoder,
        test_noise_schedules,
        test_full_pipeline,
    ]
    passed = 0
    failed = 0
    for test_fn in tests:
        try:
            result = test_fn()
        except Exception as e:
            # Keep going so one broken test doesn't hide the others.
            print(f"\n❌ {test_fn.__name__} FAILED: {e}")
            traceback.print_exc()
            failed += 1
            continue
        if result:
            passed += 1
        else:
            # A falsy return counts as a failure; report it instead of
            # silently bumping the counter.
            print(f"\n❌ {test_fn.__name__} FAILED: returned {result!r}")
            failed += 1
    print("\n" + "=" * 60)
    print(f"RESULTS: {passed} passed, {failed} failed out of {len(tests)} tests")
    print("=" * 60)
    # Propagate failure through the exit status so CI / shell callers notice it.
    sys.exit(1 if failed else 0)