Upload test_model.py
Browse files- test_model.py +21 -0
test_model.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Quick sanity test of LuminaRS without heavy deps."""
|
| 2 |
+
import torch
|
| 3 |
+
from luminars.model import LuminaRS
|
| 4 |
+
from luminars.config import LuminaRSConfig
|
| 5 |
+
|
| 6 |
+
def test():
    """Smoke-test LuminaRS: report parameter counts and run one forward pass.

    Builds a model from the default ``LuminaRSConfig``, prints total and
    trainable parameter counts, then feeds random latent/text/timestep
    tensors through the model to confirm the forward pass executes and
    reports the output shape. Returns nothing; output goes to stdout.
    """
    cfg = LuminaRSConfig()
    model = LuminaRS(cfg)

    # Hoist both counts out of the f-strings: one readable pass per count
    # instead of a sum() expression buried inside a format spec.
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params: {total/1e6:.1f}M")
    print(f"Trainable params: {trainable/1e6:.1f}M")

    # Random inputs shaped by the config; values are arbitrary — we only
    # care that the forward pass runs without shape errors.
    # NOTE(review): assumes the model's forward signature is
    # (latent, text_embedding, timestep) — matches the call below.
    bs, d, h, w = 2, cfg.latent_dim, cfg.latent_h, cfg.latent_w
    z = torch.randn(bs, d, h, w)
    text = torch.randn(bs, cfg.max_text_len, cfg.text_embed_dim)
    t = torch.rand(bs)

    # eval() disables train-time behavior (e.g. dropout); no_grad() skips
    # autograd bookkeeping — this is a sanity check, not a training step,
    # so building a computation graph is pure waste.
    model.eval()
    with torch.no_grad():
        out = model(z, text, t)
    print(f"Forward OK, output shape: {out.shape}")
|
| 19 |
+
|
| 20 |
+
# Allow running the sanity check directly: `python test_model.py`.
if __name__ == "__main__":
    test()
|