"""Flow Matching training utilities for IRIS."""

import math
from typing import Optional

import torch
import torch.nn.functional as F

# Latent scaling constant for the DC-AE f32c32 autoencoder: latents are
# multiplied by this before flow matching and divided by it after sampling
# so the model sees roughly unit-variance inputs.
DCAE_F32C32_SCALE = 0.41407


def sample_timesteps_logit_normal(batch_size, device, mean=0.0, std=1.0):
    """Sample timesteps in (0, 1) from a logit-normal distribution.

    Draws u ~ N(mean, std) and pushes it through a sigmoid, which
    concentrates samples toward the middle of [0, 1].

    Args:
        batch_size: Number of timesteps to draw.
        device: Torch device for the returned tensor.
        mean: Mean of the underlying normal.
        std: Standard deviation of the underlying normal.

    Returns:
        Tensor of shape (batch_size,), clamped to [1e-5, 1 - 1e-5] to avoid
        degenerate endpoints.
    """
    u = torch.normal(mean=mean, std=std, size=(batch_size,), device=device)
    return torch.sigmoid(u).clamp(1e-5, 1.0 - 1e-5)


def sample_timesteps_uniform(batch_size, device):
    """Sample timesteps uniformly in (0, 1), clamped away from the endpoints.

    Args:
        batch_size: Number of timesteps to draw.
        device: Torch device for the returned tensor.

    Returns:
        Tensor of shape (batch_size,) in [1e-5, 1 - 1e-5].
    """
    return torch.rand(batch_size, device=device).clamp(1e-5, 1.0 - 1e-5)


def rectified_flow_forward(z_0, t, noise=None):
    """Build the rectified-flow interpolant and its velocity target.

    Computes z_t = t * noise + (1 - t) * z_0, whose time derivative
    (the regression target for the velocity model) is noise - z_0.

    Args:
        z_0: Clean latents of shape (B, C, H, W).
        t: Timesteps of shape (B,), broadcast over the spatial dims.
        noise: Optional pre-drawn Gaussian noise matching z_0; drawn with
            torch.randn_like when omitted.

    Returns:
        Tuple (z_t, target) where target = noise - z_0.
    """
    if noise is None:
        noise = torch.randn_like(z_0)
    # Reshape (B,) -> (B, 1, 1, 1) so t broadcasts over channel/spatial dims.
    t_expand = t.view(-1, 1, 1, 1)
    z_t = t_expand * noise + (1.0 - t_expand) * z_0
    target = noise - z_0
    return z_t, target


def flow_matching_loss(model, z_0, context, num_iterations=4,
                       timestep_sampling="logit_normal",
                       scale_factor=DCAE_F32C32_SCALE):
    """Compute the flow-matching MSE loss for one batch.

    Scales the latents, samples per-example timesteps, forms the noisy
    interpolant, and regresses the model's predicted velocity against
    noise - z_0.

    Args:
        model: Callable as model(z_t, t, context, num_iterations=...)
            returning a predicted velocity the same shape as z_t.
        z_0: Clean latents of shape (B, C, H, W).
        context: Conditioning passed through to the model.
        num_iterations: Forwarded to the model.
        timestep_sampling: "logit_normal" (default) or anything else for
            uniform timestep sampling.
        scale_factor: Latent scaling applied before the flow.

    Returns:
        Dict with "loss" (differentiable scalar) and "flow_loss"
        (detached copy for logging).
    """
    B = z_0.shape[0]
    device = z_0.device
    z_0_scaled = z_0 * scale_factor
    if timestep_sampling == "logit_normal":
        t = sample_timesteps_logit_normal(B, device)
    else:
        t = sample_timesteps_uniform(B, device)
    noise = torch.randn_like(z_0_scaled)
    z_t, target = rectified_flow_forward(z_0_scaled, t, noise)
    v_pred = model(z_t, t, context, num_iterations=num_iterations)
    flow_loss = F.mse_loss(v_pred, target)
    return {"loss": flow_loss, "flow_loss": flow_loss.detach()}


@torch.no_grad()
def euler_sample(model, noise, context, num_steps=20, num_iterations=4,
                 cfg_scale=1.0, scale_factor=DCAE_F32C32_SCALE):
    """Integrate the learned flow from noise (t=1) to data (t=0) with Euler steps.

    Args:
        model: Callable as model(z_t, t, context, num_iterations=...)
            returning a predicted velocity the same shape as z_t.
        noise: Initial latent noise of shape (B, C, H, W); defines t=1 state.
        context: Conditioning; zeros_like(context) is used as the
            unconditional branch for classifier-free guidance.
        num_steps: Number of Euler integration steps.
        num_iterations: Forwarded to the model.
        cfg_scale: Classifier-free guidance scale; guidance is applied
            only when > 1.0.
        scale_factor: Latent scaling undone on the returned sample.

    Returns:
        Sampled latents divided by scale_factor (back in autoencoder scale).
    """
    # Time runs 1 -> 0, so each Euler step moves by a negative dt.
    dt = -1.0 / num_steps
    z_t = noise.clone()
    for i in range(num_steps):
        t_val = 1.0 - i / num_steps
        t = torch.full((noise.shape[0],), t_val,
                       device=noise.device, dtype=noise.dtype)
        if cfg_scale > 1.0:
            # Classifier-free guidance: run conditional and unconditional
            # branches in one doubled batch, then extrapolate from the
            # unconditional prediction toward the conditional one.
            z_double = torch.cat([z_t, z_t], dim=0)
            t_double = torch.cat([t, t], dim=0)
            ctx_double = torch.cat([context, torch.zeros_like(context)], dim=0)
            v_pred = model(z_double, t_double, ctx_double,
                           num_iterations=num_iterations)
            v_cond, v_uncond = v_pred.chunk(2, dim=0)
            v_pred = v_uncond + cfg_scale * (v_cond - v_uncond)
        else:
            v_pred = model(z_t, t, context, num_iterations=num_iterations)
        z_t = z_t + v_pred * dt
    return z_t / scale_factor