Upload luminars/flow_loss.py
Browse files- luminars/flow_loss.py +19 -0
luminars/flow_loss.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Flow matching loss for LuminaRS training."""
|
| 2 |
+
import torch, torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
def flow_loss(model, z_clean, text_emb):
    """Optimal-transport conditional flow matching loss.

    Draws a per-sample time t ~ U(0, 1), linearly interpolates between
    Gaussian noise (t = 0) and the clean latent (t = 1), and regresses the
    model's predicted velocity onto the constant OT target
    v = z_clean - z_noise.

    Args:
        model: callable invoked as ``model(z_t, text_emb, t)``; must return
            a velocity prediction with the same shape as ``z_t``.
        z_clean: clean VAE latents with batch dimension first — typically
            (B, C, H, W), but any rank >= 1 is supported.
        text_emb: (B, L, D) CLIP text embeddings, passed through to
            ``model`` unchanged.

    Returns:
        Scalar MSE loss between predicted and target velocities.
    """
    B = z_clean.shape[0]
    device = z_clean.device
    t = torch.rand(B, device=device)
    z_noise = torch.randn_like(z_clean)
    # Broadcast t over all non-batch dims. Using view(B, 1, ..., 1) instead
    # of hard-coded t[:, None, None, None] supports latents of any rank,
    # not just 4-D (B, C, H, W).
    t_view = t.view(B, *([1] * (z_clean.dim() - 1)))
    z_t = (1 - t_view) * z_noise + t_view * z_clean
    v_target = z_clean - z_noise
    v_pred = model(z_t, text_emb, t)
    return F.mse_loss(v_pred, v_target)
|