asdf98 commited on
Commit
7e03082
Β·
verified Β·
1 Parent(s): 3afddfb

Upload train_stage1.py

Browse files
Files changed (1) hide show
  1. train_stage1.py +35 -0
train_stage1.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LuminaRS Stage 1: Core Flow-Matching Training
3
+ Trains the denoiser on art/illustration data with flow matching.
4
+ Colab A100 compatible. Uses frozen pretrained VAE + CLIP.
5
+ """
6
+ import os, math, torch, torch.nn.functional as F
7
+ from torch.utils.data import DataLoader
8
+ from datasets import load_dataset
9
+ from torchvision import transforms
10
+ from transformers import CLIPTextModel, CLIPTokenizer
11
+ from diffusers import AutoencoderKL
12
+ from luminars.model import LuminaRS
13
+ from luminars.config import LuminaRSConfig
14
+
15
# ── Flow Matching Loss ──────────────────────────────────────────────────
def flow_matching_loss(model, vae, clip, z0, text_emb):
    """Optimal-transport flow-matching loss on clean latents.

    Samples a uniform time t and Gaussian noise x0, forms the straight-line
    interpolant x_t = (1 - t) * x0 + t * x1 with x1 = z0, and regresses the
    model's predicted velocity against the constant target v = x1 - x0.

    Args:
        model: denoiser called as ``model(x_t, text_emb, t)`` -> predicted velocity.
        vae: unused here; kept for signature parity with the pipeline wrapper.
        clip: unused here; kept for signature parity with the pipeline wrapper.
        z0: clean latent batch; indexing assumes shape (B, C, H, W).
        text_emb: text conditioning, passed through to ``model`` unchanged.

    Returns:
        Scalar MSE loss between predicted and target velocity.

    NOTE(review): a later module-level ``def`` in this file reuses this
    function's name and shadows it — confirm intended naming with the author.
    """
    B = z0.shape[0]
    t = torch.rand(B, device=z0.device)
    x1 = z0  # clean latent (data endpoint of the flow)
    # BUG FIX: original read `torch.randn_like(z1)` — `z1` is undefined
    # (NameError). The noise endpoint must match the data tensor's shape.
    x0 = torch.randn_like(x1)  # noise endpoint
    # Linear interpolation along the straight noise→data path.
    xt = (1 - t[:, None, None, None]) * x0 + t[:, None, None, None] * x1
    # Target velocity is constant along the straight line.
    v_target = x1 - x0
    v_pred = model(xt, text_emb, t)
    return F.mse_loss(v_pred, v_target)
28
+
29
def flow_matching_loss(model, vae, clip, pixel_images, text_tokens):
    """Full pipeline: pixels -> frozen VAE latents -> flow-matching loss.

    Encodes images with the (frozen) VAE, scales latents by the VAE's
    configured scaling factor, embeds text with CLIP, then computes the
    optimal-transport flow-matching MSE on the latents.

    BUG FIX(review): this ``def`` reuses the name of the latent-space loss
    defined earlier in the file, shadowing it at module level — the original
    tail call ``flow_matching_loss(model, vae, clip, latents, text_emb)``
    therefore recursed into *itself* forever (RecursionError). The
    flow-matching step is inlined below so this function is self-contained.

    Args:
        model: denoiser called as ``model(x_t, text_emb, t)`` -> predicted velocity.
        vae: frozen autoencoder exposing ``encode(...).latent_dist.sample()``
            and ``config.scaling_factor``.
        clip: frozen text encoder; ``clip(text_tokens).last_hidden_state``
            is used as conditioning.
        pixel_images: image batch fed to ``vae.encode``.
        text_tokens: tokenized prompts fed to ``clip``.

    Returns:
        Tuple ``(loss, latents, text_emb)`` — scalar MSE loss, the scaled
        latents, and the CLIP hidden states.
    """
    # Encoders are frozen: no gradients through VAE or CLIP.
    with torch.no_grad():
        latents = vae.encode(pixel_images).latent_dist.sample()
        latents = latents * vae.config.scaling_factor
        text_emb = clip(text_tokens).last_hidden_state
    # Inlined optimal-transport flow matching (was a recursive self-call).
    B = latents.shape[0]
    t = torch.rand(B, device=latents.device)
    x1 = latents                 # clean latent (data endpoint)
    x0 = torch.randn_like(x1)    # noise endpoint
    xt = (1 - t[:, None, None, None]) * x0 + t[:, None, None, None] * x1
    v_target = x1 - x0           # constant straight-line velocity
    v_pred = model(xt, text_emb, t)
    loss = F.mse_loss(v_pred, v_target)
    return loss, latents, text_emb