Upload test_model.py
Browse files- test_model.py +133 -0
test_model.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
LiquidDiffusion — Complete Test Suite

Tests model construction, forward/backward, training stability, and sampling.
Run: python test_model.py
"""
import sys
import math
import torch
import torch.nn.functional as F

# Add parent directory to path so `liquid_diffusion` resolves when run in-place.
sys.path.insert(0, '.')

from liquid_diffusion.model import (
    LiquidDiffusionUNet, liquid_diffusion_tiny,
    liquid_diffusion_small, liquid_diffusion_base
)

print("=" * 70)
print("LiquidDiffusion: Novel Attention-Free Image Generation")
print("Based on Liquid Neural Networks (CfC) + Rectified Flow")
print("=" * 70)

# Flipped to False by any failing check; drives the final summary and exit code.
all_passed = True

# Test 1: Model construction
print("\n--- Test 1: Model Construction & Parameter Count ---")
for name, factory in [("tiny", liquid_diffusion_tiny), ("small", liquid_diffusion_small), ("base", liquid_diffusion_base)]:
    m = factory()
    total, trainable = m.count_params()
    print(f" {name:8s}: {total:>12,} params ({total/1e6:.1f}M)")
    del m  # free the model before building the next size

# Test 2: Forward pass
print("\n--- Test 2: Forward Pass (multiple resolutions) ---")
model = liquid_diffusion_tiny()
for res in [32, 64, 128]:
    x = torch.randn(2, 3, res, res)
    t = torch.rand(2)
    out = model(x, t)
    # The velocity field must match the input shape exactly.
    ok = out.shape == x.shape
    print(f" {res}x{res}: {'OK' if ok else 'FAIL'} shape={out.shape}")
    if not ok:
        all_passed = False

# Test 3: Backward pass
print("\n--- Test 3: Backward Pass (gradient flow) ---")
model = liquid_diffusion_tiny()
x = torch.randn(2, 3, 64, 64)
t = torch.rand(2)
out = model(x, t)
loss = out.mean()
loss.backward()
num_params_with_grad = sum(1 for p in model.parameters() if p.grad is not None)
nan_grads = sum(1 for p in model.parameters() if p.grad is not None and torch.isnan(p.grad).any())
print(f" Params with gradients: {num_params_with_grad}")
print(f" NaN gradients: {nan_grads}")
if nan_grads > 0:
    all_passed = False

# Test 4: Training stability (20 steps)
print("\n--- Test 4: Training Stability (20 steps, random data) ---")
model = liquid_diffusion_tiny()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
losses = []
for step in range(20):
    model.train()
    # Rectified-flow objective: interpolate x_t between data x0 and noise x1,
    # regress the constant velocity v = x1 - x0.
    x0 = torch.randn(4, 3, 64, 64)
    x1 = torch.randn_like(x0)
    t_val = torch.rand(4)
    x_t = (1 - t_val[:, None, None, None]) * x0 + t_val[:, None, None, None] * x1
    v_target = x1 - x0
    v_pred = model(x_t, t_val)
    loss = F.mse_loss(v_pred, v_target)
    optimizer.zero_grad()
    loss.backward()
    gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    losses.append(loss.item())
    if step % 5 == 0:
        print(f" Step {step:3d}: loss={loss.item():.4f}, grad_norm={gn.item():.4f}")

# math.isfinite rejects both NaN and +/-Inf in one call.
stable = all(math.isfinite(loss_val) for loss_val in losses)
not_exploding = max(losses) < 100
print(f" Stable (no NaN/Inf): {'OK' if stable else 'FAIL'}")
print(f" Not exploding: {'OK' if not_exploding else 'FAIL'} (max={max(losses):.4f})")
if not stable or not not_exploding:
    all_passed = False

# Test 5: Sampling
print("\n--- Test 5: Sampling (10 Euler steps) ---")
model.eval()
with torch.no_grad():
    z = torch.randn(2, 3, 64, 64)
    # Euler integration of dz/dt = v from t=1 (noise) down to t=0 (data).
    for i in range(10, 0, -1):
        t_s = torch.full((2,), i / 10.0)
        v = model(z, t_s)
        z = z - v * 0.1
    z = z.clamp(-1, 1)
    print(f" Shape: {z.shape}, range: [{z.min():.3f}, {z.max():.3f}]")

# Test 6: Timestep sensitivity — the output should vary with t for a fixed input.
print("\n--- Test 6: Timestep Sensitivity ---")
model.eval()
x = torch.randn(1, 3, 64, 64)
for t_val in [0.01, 0.25, 0.5, 0.75, 0.99]:
    with torch.no_grad():
        out = model(x, torch.tensor([t_val]))
    print(f" t={t_val:.2f}: mean={out.mean():.6f}, std={out.std():.6f}")

# Test 7: Architecture properties
print("\n--- Test 7: Architecture Properties ---")
m = liquid_diffusion_tiny()
total_blocks = (sum(len(s) for s in m.encoder_blocks) + len(m.bottleneck) + sum(len(s) for s in m.decoder_blocks))
print(f" Attention layers: 0")
print(f" Sequential loops: 0")
print(f" CfC blocks: {total_blocks}")
print(f" Training objective: Rectified Flow (MSE velocity)")

# Test 8: VRAM estimates
print("\n--- Test 8: VRAM Estimates (fp16 training) ---")
for name, factory, res, bs in [
    ("tiny 256px bs4", liquid_diffusion_tiny, 256, 4),
    ("small 256px bs4", liquid_diffusion_small, 256, 4),
    ("base 256px bs2", liquid_diffusion_base, 256, 2),
    ("tiny 512px bs2", liquid_diffusion_tiny, 512, 2),
]:
    m = factory()
    tp = sum(p.numel() for p in m.parameters())
    # Rough heuristic: fp16 weights (2B) + fp32 master copy (4B) + Adam moments (8B)
    # per parameter, plus a coarse activation-memory term — TODO confirm the 0.3 factor.
    est = (tp * 2 + tp * 4 + tp * 8) / 1e9 + bs * 3 * res * res * 4 * len(m.channels) * max(m.channels) / 1e9 * 0.3
    print(f" {name:20s}: {tp/1e6:.1f}M params, ~{est:.1f}GB VRAM")
    del m

print("\n" + "=" * 70)
print(f"ALL TESTS {'PASSED' if all_passed else 'SOME FAILURES'}")
print("=" * 70)
# Nonzero exit status lets CI runners detect a failing suite.
sys.exit(0 if all_passed else 1)