asdf98 commited on
Commit
831f50a
·
verified ·
1 Parent(s): b649f1d

Upload luminars/flow_loss.py

Browse files
Files changed (1) hide show
  1. luminars/flow_loss.py +19 -0
luminars/flow_loss.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Flow matching loss for LuminaRS training."""
2
+ import torch, torch.nn.functional as F
3
+
4
+ def flow_loss(model, z_clean, text_emb):
5
+ """
6
+ Optimal-transport conditional flow matching.
7
+ z_clean: (B,C,H,W) clean VAE latents
8
+ text_emb: (B,L,D) CLIP text embeddings
9
+ Returns: scalar MSE loss
10
+ """
11
+ B = z_clean.shape[0]
12
+ device = z_clean.device
13
+ t = torch.rand(B, device=device)
14
+ z_noise = torch.randn_like(z_clean)
15
+ t_view = t[:, None, None, None]
16
+ z_t = (1 - t_view) * z_noise + t_view * z_clean
17
+ v_target = z_clean - z_noise
18
+ v_pred = model(z_t, text_emb, t)
19
+ return F.mse_loss(v_pred, v_target)