| """Flow matching loss for LuminaRS training.""" | |
| import torch, torch.nn.functional as F | |
def flow_loss(model, z_clean, text_emb, *, t=None, noise=None):
    """
    Optimal-transport conditional flow matching loss.

    Samples a per-example time ``t ~ U[0, 1)`` and builds the straight-line
    interpolant ``z_t = (1 - t) * noise + t * z_clean`` (pure noise at t=0,
    clean latent at t=1). It then regresses the model's prediction onto the
    constant velocity of that path, ``z_clean - noise``.

    Args:
        model: callable invoked as ``model(z_t, text_emb, t)``. Must return
            a tensor with the same shape as ``z_clean``.
        z_clean: (B, ...) clean latents, e.g. (B, C, H, W) VAE latents. Any
            rank >= 1 is accepted; ``t`` is broadcast per sample over all
            non-batch dimensions.
        text_emb: (B, L, D) CLIP text embeddings, forwarded to ``model``
            unchanged.
        t: optional (B,) tensor of timesteps. Defaults to
            ``torch.rand(B, device=z_clean.device)``. Pass it explicitly for
            custom timestep schedules or reproducible evaluation.
        noise: optional tensor shaped like ``z_clean``. Defaults to
            ``torch.randn_like(z_clean)``.

    Returns:
        Scalar MSE loss (mean over all elements).

    Raises:
        ValueError: if ``t``, ``noise`` or the model output has an
            unexpected shape.
    """
    B = z_clean.shape[0]
    device = z_clean.device

    # Draw order (t first, then noise) matches the original implementation,
    # so seeded runs that rely on the defaults are reproduced exactly.
    if t is None:
        t = torch.rand(B, device=device)
    elif tuple(t.shape) != (B,):
        raise ValueError(f"t must have shape ({B},), got {tuple(t.shape)}")

    if noise is None:
        noise = torch.randn_like(z_clean)
    elif noise.shape != z_clean.shape:
        raise ValueError(
            f"noise shape {tuple(noise.shape)} does not match "
            f"z_clean shape {tuple(z_clean.shape)}"
        )

    # Reshape t to (B, 1, ..., 1) so it scales each sample across every
    # non-batch dim. A fixed t[:, None, None, None] only fits 4-D latents.
    # For 3-D latents it silently broadcasts to (B, B, C, L).
    t_view = t.reshape(B, *([1] * (z_clean.dim() - 1)))

    z_t = (1 - t_view) * noise + t_view * z_clean
    v_target = z_clean - noise
    v_pred = model(z_t, text_emb, t)

    # F.mse_loss only warns on a size mismatch and then broadcasts, which
    # would silently train against the wrong objective.
    if v_pred.shape != v_target.shape:
        raise ValueError(
            f"model output shape {tuple(v_pred.shape)} does not match "
            f"target shape {tuple(v_target.shape)}"
        )
    return F.mse_loss(v_pred, v_target)