# LuminaRS / luminars/flow_loss.py
# Uploaded by asdf98 (commit 831f50a, "Upload luminars/flow_loss.py").
"""Flow matching loss for LuminaRS training."""
import torch, torch.nn.functional as F
def flow_loss(model, z_clean, text_emb, *, generator=None):
    """Compute the optimal-transport conditional flow-matching MSE loss.

    One timestep ``t ~ U[0, 1)`` is drawn per sample. The latent is then
    placed on the straight line between Gaussian noise (``t = 0``) and the
    clean latent (``t = 1``):

        z_t = (1 - t) * noise + t * z_clean

    The model is trained to predict the velocity of that path,
    ``z_clean - noise``, which does not depend on ``t``.

    Args:
        model: Callable invoked as ``model(z_t, text_emb, t)``, where ``t``
            has shape ``(B,)``. It should return a tensor with the same
            shape as ``z_clean``.
        z_clean: Clean latents with the batch as the leading dimension,
            e.g. ``(B, C, H, W)`` VAE latents. Any rank >= 1 is supported.
        text_emb: Conditioning passed to ``model`` unchanged,
            e.g. ``(B, L, D)`` CLIP text embeddings.
        generator: Optional ``torch.Generator`` used for both the timestep
            and the noise draws. Use it for reproducible sampling. It must
            belong to the same device type as ``z_clean``.

    Returns:
        Scalar MSE loss between predicted and target velocity, computed in
        float32.
    """
    batch = z_clean.shape[0]
    device = z_clean.device
    # One timestep per sample, uniform in [0, 1); drawn before the noise.
    t = torch.rand(batch, device=device, generator=generator)
    z_noise = torch.randn(
        z_clean.shape, generator=generator, device=device, dtype=z_clean.dtype
    )
    # Reshape t to (B, 1, ..., 1) so it broadcasts against latents of any rank.
    # A fixed t[:, None, None, None] only lines up with 4-D latents and
    # silently mis-broadcasts 2-D/3-D ones.
    t_view = t.view(batch, *([1] * (z_clean.dim() - 1)))
    z_t = (1 - t_view) * z_noise + t_view * z_clean
    # Velocity of the straight noise -> data path.
    v_target = z_clean - z_noise
    v_pred = model(z_t, text_emb, t)
    # Reduce in float32 so half-precision latents don't lose accuracy in the mean.
    return F.mse_loss(v_pred.float(), v_target.float())