# LiquidFlow-Gen / test_syntax.py
# Uploaded by krystv ("Upload test_syntax.py", commit 029ca89).
"""
Quick syntax, import, and forward-pass smoke test for LiquidFlow.
Run: python test_syntax.py
"""
import math
import os
import sys
# Test 1: All files parse correctly (compile only; nothing is executed here).
print("=== Test 1: Syntax Check ===")
# Resolve paths against this script's directory rather than the current working
# directory, matching the sys.path entry added for the import checks, so the
# script works no matter where it is launched from.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
modules = [
    'liquid_flow/__init__.py',
    'liquid_flow/cfc_cell.py',
    'liquid_flow/mamba2_ssd.py',
    'liquid_flow/liquid_flow_block.py',
    'liquid_flow/generator.py',
    'liquid_flow/vae_wrapper.py',
    'liquid_flow/physics_loss.py',
    'train.py',
]
for module in modules:
    path = os.path.join(_SCRIPT_DIR, module)
    # Read raw bytes: compile() then applies the PEP 263 coding cookie / UTF-8
    # default itself instead of the platform's locale encoding (e.g. cp1252).
    with open(path, 'rb') as f:
        code = f.read()
    # Raises SyntaxError (aborting the script) if the file does not parse.
    compile(code, path, 'exec')
    print(f" ✓ {module}")
# Test 2: Module imports. Every name used by the forward-pass checks must import cleanly.
print("\n=== Test 2: Import Check ===")
# Make the local liquid_flow package importable regardless of the working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
import torch.nn as nn  # NOTE(review): currently unused in this script
from liquid_flow.cfc_cell import CfCCell, CfCBlock
print(" ✓ CfC imports")
from liquid_flow.mamba2_ssd import Mamba2SSD, Mamba2Block
print(" ✓ Mamba-2 imports")
from liquid_flow.liquid_flow_block import LiquidMambaBlock, LiquidFlowBackbone
print(" ✓ LiquidFlow block imports")
from liquid_flow.generator import LiquidFlowGenerator, create_liquidflow
print(" ✓ Generator imports")
from liquid_flow.vae_wrapper import TAESDWrapper, SDVAEWrapper
print(" ✓ VAE wrapper imports")
from liquid_flow.physics_loss import PhysicsRegularizer, DDIMEstimator
print(" ✓ Physics loss imports")
# Test 3: Forward pass. Run each component once on random input and check that
# the output shape equals the input shape.
print("\n=== Test 3: Forward Pass ===")

# CfCCell on a (2, 64) input.
cell = CfCCell(dim=64)
x = torch.randn(2, 64)
h = cell(x)
assert h.shape == x.shape, f"CfCCell shape mismatch: {x.shape} -> {h.shape}"
print(f" ✓ CfCCell: {x.shape} -> {h.shape}")

# Mamba2SSD on a (2, 256, 64) sequence input.
ssd = Mamba2SSD(dim=64, d_state=8, expand=2)
x_seq = torch.randn(2, 256, 64)
out = ssd(x_seq)
assert out.shape == x_seq.shape, f"Mamba2SSD shape mismatch: {x_seq.shape} -> {out.shape}"
print(f" ✓ Mamba2SSD: {x_seq.shape} -> {out.shape}")

# LiquidMambaBlock (2D) on a (2, 64, 16, 16) feature map.
lm = LiquidMambaBlock(dim=64, d_state=8, expand=2)
x_2d = torch.randn(2, 64, 16, 16)
out = lm(x_2d)
assert out.shape == x_2d.shape, f"LiquidMambaBlock shape mismatch: {x_2d.shape} -> {out.shape}"
print(f" ✓ LiquidMambaBlock: {x_2d.shape} -> {out.shape}")

# Full backbone: 4-channel 32x32 input plus one integer timestep per sample.
backbone = LiquidFlowBackbone(in_channels=4, hidden_dim=64, num_stages=2, blocks_per_stage=2)
x = torch.randn(2, 4, 32, 32)
t = torch.tensor([500, 750])
out = backbone(x, t)
assert out.shape == x.shape, f"Backbone shape mismatch: {x.shape} -> {out.shape}"
print(f" ✓ Backbone: {x.shape} -> {out.shape}")

# Generator, 'tiny' variant. NOTE(review): the 4x16x16 input for image_size=128
# presumably corresponds to an 8x-downsampled VAE latent; confirm.
model = create_liquidflow(variant='tiny', image_size=128)
x = torch.randn(2, 4, 16, 16)
t = torch.tensor([500, 750])
out = model(x, t)
assert out.shape == x.shape, f"Generator shape mismatch: {x.shape} -> {out.shape}"
print(f" ✓ Generator: {x.shape} -> {out.shape}")

# Physics regularizer: returns a pair (total, losses); only `total` is checked.
physics = PhysicsRegularizer()
x0 = torch.randn(2, 3, 32, 32)
total, losses = physics(x0)
total_value = total.item()
# Without this check, a NaN/inf total would be printed and reported as a pass.
assert math.isfinite(total_value), f"Physics loss is not finite: {total_value}"
print(f" ✓ Physics Loss: total={total_value:.4f}")
# Summary: trainable parameter count of the 'tiny' generator (`model`).
# This is reached only if no check above raised. NOTE: under `python -O` the
# asserts are stripped, so this banner alone is not proof that the checks ran.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
rule = "=" * 60
print(f"\n{rule}")
print("ALL TESTS PASSED! ✓")
print(f"Tiny model params: {n_params:,} ({n_params/1e6:.1f}M)")
# NOTE(review): informational claim only; nothing here measures memory or runtime.
print("Model compatible with Colab/Kaggle free tier")
print(rule)