asdf98 commited on
Commit
2fd257b
·
verified ·
1 Parent(s): 7ada470

Upload test_model.py

Browse files
Files changed (1) hide show
  1. test_model.py +21 -0
test_model.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Quick sanity test of LuminaRS without heavy deps."""
2
+ import torch
3
+ from luminars.model import LuminaRS
4
+ from luminars.config import LuminaRSConfig
5
+
6
+ def test():
7
+ cfg = LuminaRSConfig()
8
+ model = LuminaRS(cfg)
9
+ n = sum(p.numel() for p in model.parameters())
10
+ print(f"Total params: {n/1e6:.1f}M")
11
+ print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:.1f}M")
12
+
13
+ bs, d, h, w = 2, cfg.latent_dim, cfg.latent_h, cfg.latent_w
14
+ z = torch.randn(bs, d, h, w)
15
+ text = torch.randn(bs, cfg.max_text_len, cfg.text_embed_dim)
16
+ t = torch.rand(bs)
17
+ out = model(z, text, t)
18
+ print(f"Forward OK, output shape: {out.shape}")
19
+
20
+ if __name__ == "__main__":
21
+ test()