asdf98 committed on
Commit
e3431a4
·
verified ·
1 Parent(s): 7e03082

Delete train_stage1.py, test_model.py

Browse files
Files changed (2) hide show
  1. test_model.py +0 -21
  2. train_stage1.py +0 -35
test_model.py DELETED
@@ -1,21 +0,0 @@
1
- """Quick sanity test of LuminaRS without heavy deps."""
2
- import torch
3
- from luminars.model import LuminaRS
4
- from luminars.config import LuminaRSConfig
5
-
6
- def test():
7
- cfg = LuminaRSConfig()
8
- model = LuminaRS(cfg)
9
- n = sum(p.numel() for p in model.parameters())
10
- print(f"Total params: {n/1e6:.1f}M")
11
- print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:.1f}M")
12
-
13
- bs, d, h, w = 2, cfg.latent_dim, cfg.latent_h, cfg.latent_w
14
- z = torch.randn(bs, d, h, w)
15
- text = torch.randn(bs, cfg.max_text_len, cfg.text_embed_dim)
16
- t = torch.rand(bs)
17
- out = model(z, text, t)
18
- print(f"Forward OK, output shape: {out.shape}")
19
-
20
- if __name__ == "__main__":
21
- test()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_stage1.py DELETED
@@ -1,35 +0,0 @@
1
- """
2
- LuminaRS Stage 1: Core Flow-Matching Training
3
- Trains the denoiser on art/illustration data with flow matching.
4
- Colab A100 compatible. Uses frozen pretrained VAE + CLIP.
5
- """
6
- import os, math, torch, torch.nn.functional as F
7
- from torch.utils.data import DataLoader
8
- from datasets import load_dataset
9
- from torchvision import transforms
10
- from transformers import CLIPTextModel, CLIPTokenizer
11
- from diffusers import AutoencoderKL
12
- from luminars.model import LuminaRS
13
- from luminars.config import LuminaRSConfig
14
-
15
- # ── Flow Matching Loss ──────────────────────────────────────────────────
16
- def flow_matching_loss(model, vae, clip, z0, text_emb):
17
- """Optimal-transport flow matching: v(x_t, t) = x1 - x0"""
18
- B = z0.shape[0]
19
- t = torch.rand(B, device=z0.device)
20
- x1 = z0 # clean latent
21
- x0 = torch.randn_like(z1) # noise
22
- # Linear interpolation
23
- xt = (1 - t[:,None,None,None]) * x0 + t[:,None,None,None] * x1
24
- # Target velocity (straight line)
25
- v_target = x1 - x0
26
- v_pred = model(xt, text_emb, t)
27
- return F.mse_loss(v_pred, v_target)
28
-
29
- def flow_matching_loss(model, vae, clip, pixel_images, text_tokens):
30
- """Full pipeline: image -> VAE encode -> flow matching."""
31
- with torch.no_grad():
32
- latents = vae.encode(pixel_images).latent_dist.sample()
33
- latents = latents * vae.config.scaling_factor
34
- text_emb = clip(text_tokens).last_hidden_state
35
- return flow_matching_loss(model, vae, clip, latents, text_emb), latents, text_emb