Delete train_stage1.py, test_model.py
Browse files- test_model.py +0 -21
- train_stage1.py +0 -35
test_model.py
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
"""Quick sanity test of LuminaRS without heavy deps."""
|
| 2 |
-
import torch
|
| 3 |
-
from luminars.model import LuminaRS
|
| 4 |
-
from luminars.config import LuminaRSConfig
|
| 5 |
-
|
| 6 |
-
def test():
    """Smoke-test LuminaRS: build the model, count parameters, and run
    a single forward pass on random data to confirm shapes line up."""
    cfg = LuminaRSConfig()
    model = LuminaRS(cfg)

    # Report parameter counts in millions.
    total = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total/1e6:.1f}M")
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable params: {trainable/1e6:.1f}M")

    # Dummy batch sized from the config's latent grid and text length.
    batch = 2
    latent = torch.randn(batch, cfg.latent_dim, cfg.latent_h, cfg.latent_w)
    text_cond = torch.randn(batch, cfg.max_text_len, cfg.text_embed_dim)
    timestep = torch.rand(batch)

    out = model(latent, text_cond, timestep)
    print(f"Forward OK, output shape: {out.shape}")
# Allow running the sanity check directly as a script.
if __name__ == "__main__":
    test()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_stage1.py
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
LuminaRS Stage 1: Core Flow-Matching Training
|
| 3 |
-
Trains the denoiser on art/illustration data with flow matching.
|
| 4 |
-
Colab A100 compatible. Uses frozen pretrained VAE + CLIP.
|
| 5 |
-
"""
|
| 6 |
-
import os, math, torch, torch.nn.functional as F
|
| 7 |
-
from torch.utils.data import DataLoader
|
| 8 |
-
from datasets import load_dataset
|
| 9 |
-
from torchvision import transforms
|
| 10 |
-
from transformers import CLIPTextModel, CLIPTokenizer
|
| 11 |
-
from diffusers import AutoencoderKL
|
| 12 |
-
from luminars.model import LuminaRS
|
| 13 |
-
from luminars.config import LuminaRSConfig
|
| 14 |
-
|
| 15 |
-
# ── Flow Matching Loss ──────────────────────────────────────────────────
|
| 16 |
-
def flow_matching_loss(model, vae, clip, z0, text_emb):
    """Optimal-transport flow matching loss on pre-encoded latents.

    Draws a uniform time t per sample, linearly interpolates between a
    Gaussian noise endpoint x0 and the clean latent x1 = z0, and regresses
    the model's predicted velocity onto the straight-line target x1 - x0.

    Args:
        model: denoiser called as ``model(xt, text_emb, t)`` -> velocity.
        vae, clip: unused here; kept only for signature parity with the
            pixel-space pipeline variant.
        z0: clean latent batch, shape (B, C, H, W).
        text_emb: text conditioning tensor, passed through to the model.

    Returns:
        Scalar MSE between predicted and target velocity.
    """
    B = z0.shape[0]
    t = torch.rand(B, device=z0.device)
    x1 = z0  # clean latent (flow endpoint at t=1)
    # BUG FIX: the original sampled torch.randn_like(z1), but `z1` was never
    # defined anywhere in the function (NameError at runtime). The noise
    # endpoint is standard Gaussian with the same shape as the clean latent.
    x0 = torch.randn_like(x1)
    # Linear (optimal-transport) interpolation between noise and data.
    xt = (1 - t[:, None, None, None]) * x0 + t[:, None, None, None] * x1
    # Target velocity is the straight line from noise to data.
    v_target = x1 - x0
    v_pred = model(xt, text_emb, t)
    return F.mse_loss(v_pred, v_target)
|
| 29 |
-
def flow_matching_loss(model, vae, clip, pixel_images, text_tokens):
    """Full pixel-space pipeline: VAE-encode images, then flow matching.

    NOTE(review): this ``def`` shadows the latent-space
    ``flow_matching_loss`` defined earlier in the file, and as originally
    written its inner call to ``flow_matching_loss`` resolved to *itself*
    (identical 5-argument signature), recursing until stack overflow.
    Fixed by computing the flow-matching objective inline; consider giving
    the two functions distinct names.

    Args:
        model: denoiser called as ``model(xt, text_emb, t)`` -> velocity.
        vae: frozen autoencoder; uses ``.encode(x).latent_dist.sample()``
            and ``.config.scaling_factor``.
        clip: frozen text encoder; uses ``clip(tokens).last_hidden_state``.
        pixel_images: image batch in pixel space, shape (B, C, H, W).
        text_tokens: tokenized captions for ``clip``.

    Returns:
        ``(loss, latents, text_emb)`` — scalar MSE loss plus the encoded
        latents and text embeddings so the caller can reuse them.
    """
    # Encoders are frozen: no gradients through VAE or CLIP.
    with torch.no_grad():
        latents = vae.encode(pixel_images).latent_dist.sample()
        latents = latents * vae.config.scaling_factor
        text_emb = clip(text_tokens).last_hidden_state

    # Optimal-transport flow matching on the encoded latents
    # (same objective as the latent-space variant, inlined).
    B = latents.shape[0]
    t = torch.rand(B, device=latents.device)
    x1 = latents                # clean latent endpoint (t=1)
    x0 = torch.randn_like(x1)   # Gaussian noise endpoint (t=0)
    xt = (1 - t[:, None, None, None]) * x0 + t[:, None, None, None] * x1
    v_target = x1 - x0          # straight-line target velocity
    v_pred = model(xt, text_emb, t)
    loss = F.mse_loss(v_pred, v_target)
    return loss, latents, text_emb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|