"""Flow matching loss for LuminaRS training.""" import torch, torch.nn.functional as F def flow_loss(model, z_clean, text_emb): """ Optimal-transport conditional flow matching. z_clean: (B,C,H,W) clean VAE latents text_emb: (B,L,D) CLIP text embeddings Returns: scalar MSE loss """ B = z_clean.shape[0] device = z_clean.device t = torch.rand(B, device=device) z_noise = torch.randn_like(z_clean) t_view = t[:, None, None, None] z_t = (1 - t_view) * z_noise + t_view * z_clean v_target = z_clean - z_noise v_pred = model(z_t, text_emb, t) return F.mse_loss(v_pred, v_target)