krystv commited on
Commit
421b295
·
verified ·
1 Parent(s): 94fbb93

Upload test_model.py

Browse files
Files changed (1) hide show
  1. test_model.py +133 -0
test_model.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LiquidDiffusion — Complete Test Suite
3
+ Tests model construction, forward/backward, training stability, and sampling.
4
+ Run: python test_model.py
5
+ """
6
+ import sys
7
+ import math
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ # Add parent directory to path
12
+ sys.path.insert(0, '.')
13
+
14
+ from liquid_diffusion.model import (
15
+ LiquidDiffusionUNet, liquid_diffusion_tiny,
16
+ liquid_diffusion_small, liquid_diffusion_base
17
+ )
18
+
19
+ print("=" * 70)
20
+ print("LiquidDiffusion: Novel Attention-Free Image Generation")
21
+ print("Based on Liquid Neural Networks (CfC) + Rectified Flow")
22
+ print("=" * 70)
23
+
24
+ all_passed = True
25
+
# Test 1: Model construction
print("\n--- Test 1: Model Construction & Parameter Count ---")
# Build each size variant once, report its parameter count, then free it.
variants = [
    ("tiny", liquid_diffusion_tiny),
    ("small", liquid_diffusion_small),
    ("base", liquid_diffusion_base),
]
for name, factory in variants:
    m = factory()
    total, trainable = m.count_params()
    print(f" {name:8s}: {total:>12,} params ({total/1e6:.1f}M)")
    del m

# Test 2: Forward pass
print("\n--- Test 2: Forward Pass (multiple resolutions) ---")
model = liquid_diffusion_tiny()
for res in [32, 64, 128]:
    x = torch.randn(2, 3, res, res)
    t = torch.rand(2)
    # Shape-only check: no gradients are needed, so skip autograd graph
    # construction to keep memory flat across the larger resolutions.
    with torch.no_grad():
        out = model(x, t)
    ok = out.shape == x.shape
    print(f" {res}x{res}: {'OK' if ok else 'FAIL'} shape={out.shape}")
    if not ok:
        all_passed = False

# Test 3: Backward pass
print("\n--- Test 3: Backward Pass (gradient flow) ---")
model = liquid_diffusion_tiny()
x = torch.randn(2, 3, 64, 64)
t = torch.rand(2)
out = model(x, t)
loss = out.mean()
loss.backward()
# Collect populated gradients once, then derive both counts from the list.
grads = [p.grad for p in model.parameters() if p.grad is not None]
num_params_with_grad = len(grads)
nan_grads = sum(1 for g in grads if torch.isnan(g).any())
print(f" Params with gradients: {num_params_with_grad}")
print(f" NaN gradients: {nan_grads}")
if nan_grads > 0:
    all_passed = False

# Test 4: Training stability (20 steps)
print("\n--- Test 4: Training Stability (20 steps, random data) ---")
model = liquid_diffusion_tiny()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
losses = []
for step in range(20):
    model.train()
    x0 = torch.randn(4, 3, 64, 64)
    x1 = torch.randn_like(x0)
    t_val = torch.rand(4)
    # Rectified-flow interpolation: x_t = (1-t)*x0 + t*x1, target v = x1 - x0.
    t_b = t_val[:, None, None, None]  # broadcast t over (C, H, W)
    x_t = (1 - t_b) * x0 + t_b * x1
    v_target = x1 - x0
    v_pred = model(x_t, t_val)
    loss = F.mse_loss(v_pred, v_target)
    optimizer.zero_grad()
    loss.backward()
    gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    losses.append(loss.item())
    if step % 5 == 0:
        print(f" Step {step:3d}: loss={losses[-1]:.4f}, grad_norm={gn.item():.4f}")

# math.isfinite is False for both NaN and +/-Inf, replacing the double check;
# max(losses) is hoisted so it is computed once.
stable = all(math.isfinite(loss_val) for loss_val in losses)
max_loss = max(losses)
not_exploding = max_loss < 100
print(f" Stable (no NaN/Inf): {'OK' if stable else 'FAIL'}")
print(f" Not exploding: {'OK' if not_exploding else 'FAIL'} (max={max_loss:.4f})")
if not stable or not not_exploding:
    all_passed = False

# Test 5: Sampling
print("\n--- Test 5: Sampling (10 Euler steps) ---")
model.eval()
with torch.no_grad():
    z = torch.randn(2, 3, 64, 64)
    # Integrate the learned velocity field backwards from t=1.0 down to
    # t=0.1 with a fixed Euler step of 0.1.
    for i in reversed(range(1, 11)):
        t_s = torch.full((2,), i / 10.0)
        v = model(z, t_s)
        z = z - v * 0.1
    z = z.clamp(-1, 1)
    print(f" Shape: {z.shape}, range: [{z.min():.3f}, {z.max():.3f}]")

# Test 6: Timestep sensitivity
print("\n--- Test 6: Timestep Sensitivity ---")
model.eval()
x = torch.randn(1, 3, 64, 64)
# Same input at several timesteps: output statistics should vary with t
# if the timestep conditioning is actually wired through the network.
for t_val in (0.01, 0.25, 0.5, 0.75, 0.99):
    with torch.no_grad():
        out = model(x, torch.tensor([t_val]))
    print(f" t={t_val:.2f}: mean={out.mean():.6f}, std={out.std():.6f}")

# Test 7: Architecture properties
print("\n--- Test 7: Architecture Properties ---")
m = liquid_diffusion_tiny()
# Count CfC blocks across encoder stages, bottleneck, and decoder stages.
encoder_count = sum(len(stage) for stage in m.encoder_blocks)
decoder_count = sum(len(stage) for stage in m.decoder_blocks)
total_blocks = encoder_count + len(m.bottleneck) + decoder_count
print(f" Attention layers: 0")
print(f" Sequential loops: 0")
print(f" CfC blocks: {total_blocks}")
print(f" Training objective: Rectified Flow (MSE velocity)")

# Test 8: VRAM estimates
print("\n--- Test 8: VRAM Estimates (fp16 training) ---")
configs = (
    ("tiny 256px bs4", liquid_diffusion_tiny, 256, 4),
    ("small 256px bs4", liquid_diffusion_small, 256, 4),
    ("base 256px bs2", liquid_diffusion_base, 256, 2),
    ("tiny 512px bs2", liquid_diffusion_tiny, 512, 2),
)
for name, factory, res, bs in configs:
    m = factory()
    tp = sum(p.numel() for p in m.parameters())
    # Rough heuristic: fp16 weights (2B) + fp32 master copy (4B) + optimizer
    # moments (8B) per parameter, plus a discounted activation-memory term.
    param_gb = (tp * 2 + tp * 4 + tp * 8) / 1e9
    act_gb = bs * 3 * res * res * 4 * len(m.channels) * max(m.channels) / 1e9 * 0.3
    est = param_gb + act_gb
    print(f" {name:20s}: {tp/1e6:.1f}M params, ~{est:.1f}GB VRAM")
    del m

# Final banner: summarize the aggregate pass/fail flag set by the tests above.
print("\n" + "=" * 70)
verdict = 'PASSED' if all_passed else 'SOME FAILURES'
print(f"ALL TESTS {verdict}")
print("=" * 70)