krystv commited on
Commit
029ca89
Β·
verified Β·
1 Parent(s): 943ab10

Upload test_syntax.py

Browse files
Files changed (1) hide show
  1. test_syntax.py +104 -0
test_syntax.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quick syntax and import test for LiquidFlow.
3
+ Run: python test_syntax.py
4
+ """
5
+ import sys
6
+ import os
7
+
8
+ # Test 1: All files parse correctly
9
+ print("=== Test 1: Syntax Check ===")
10
+ modules = [
11
+ 'liquid_flow/__init__.py',
12
+ 'liquid_flow/cfc_cell.py',
13
+ 'liquid_flow/mamba2_ssd.py',
14
+ 'liquid_flow/liquid_flow_block.py',
15
+ 'liquid_flow/generator.py',
16
+ 'liquid_flow/vae_wrapper.py',
17
+ 'liquid_flow/physics_loss.py',
18
+ 'train.py',
19
+ ]
20
+
21
+ for module in modules:
22
+ with open(module, 'r') as f:
23
+ code = f.read()
24
+ compile(code, module, 'exec')
25
+ print(f" βœ“ {module}")
26
+
27
+ # Test 2: Module imports
28
+ print("\n=== Test 2: Import Check ===")
29
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+
34
+ from liquid_flow.cfc_cell import CfCCell, CfCBlock
35
+ print(" βœ“ CfC imports")
36
+
37
+ from liquid_flow.mamba2_ssd import Mamba2SSD, Mamba2Block
38
+ print(" βœ“ Mamba-2 imports")
39
+
40
+ from liquid_flow.liquid_flow_block import LiquidMambaBlock, LiquidFlowBackbone
41
+ print(" βœ“ LiquidFlow block imports")
42
+
43
+ from liquid_flow.generator import LiquidFlowGenerator, create_liquidflow
44
+ print(" βœ“ Generator imports")
45
+
46
+ from liquid_flow.vae_wrapper import TAESDWrapper, SDVAEWrapper
47
+ print(" βœ“ VAE wrapper imports")
48
+
49
+ from liquid_flow.physics_loss import PhysicsRegularizer, DDIMEstimator
50
+ print(" βœ“ Physics loss imports")
51
+
52
+ # Test 3: Forward pass
53
+ print("\n=== Test 3: Forward Pass ===")
54
+
55
+ # CfCCell
56
+ cell = CfCCell(dim=64)
57
+ x = torch.randn(2, 64)
58
+ h = cell(x)
59
+ assert h.shape == x.shape
60
+ print(f" βœ“ CfCCell: {x.shape} -> {h.shape}")
61
+
62
+ # Mamba2SSD
63
+ ssd = Mamba2SSD(dim=64, d_state=8, expand=2)
64
+ x_seq = torch.randn(2, 256, 64)
65
+ out = ssd(x_seq)
66
+ assert out.shape == x_seq.shape
67
+ print(f" βœ“ Mamba2SSD: {x_seq.shape} -> {out.shape}")
68
+
69
+ # LiquidMambaBlock (2D)
70
+ lm = LiquidMambaBlock(dim=64, d_state=8, expand=2)
71
+ x_2d = torch.randn(2, 64, 16, 16)
72
+ out = lm(x_2d)
73
+ assert out.shape == x_2d.shape
74
+ print(f" βœ“ LiquidMambaBlock: {x_2d.shape} -> {out.shape}")
75
+
76
+ # Full backbone
77
+ backbone = LiquidFlowBackbone(in_channels=4, hidden_dim=64, num_stages=2, blocks_per_stage=2)
78
+ x = torch.randn(2, 4, 32, 32)
79
+ t = torch.tensor([500, 750])
80
+ out = backbone(x, t)
81
+ assert out.shape == x.shape
82
+ print(f" βœ“ Backbone: {x.shape} -> {out.shape}")
83
+
84
+ # Generator
85
+ model = create_liquidflow(variant='tiny', image_size=128)
86
+ x = torch.randn(2, 4, 16, 16)
87
+ t = torch.tensor([500, 750])
88
+ out = model(x, t)
89
+ assert out.shape == x.shape
90
+ print(f" βœ“ Generator: {x.shape} -> {out.shape}")
91
+
92
+ # Physics loss
93
+ physics = PhysicsRegularizer()
94
+ x0 = torch.randn(2, 3, 32, 32)
95
+ total, losses = physics(x0)
96
+ print(f" βœ“ Physics Loss: total={total.item():.4f}")
97
+
98
+ # Count params
99
+ n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
100
+ print(f"\n{'='*60}")
101
+ print(f"ALL TESTS PASSED! βœ“")
102
+ print(f"Tiny model params: {n_params:,} ({n_params/1e6:.1f}M)")
103
+ print(f"Model compatible with Colab/Kaggle free tier")
104
+ print(f"{'='*60}")