# Source: iris-image-gen / iris/flow_matching.py
# Uploaded by asdf98 (commit 326c29a, verified)
"""Flow Matching training utilities for IRIS."""
import torch
import torch.nn.functional as F
import math
from typing import Optional
# Latent scaling factor applied before flow matching and removed after sampling.
# NOTE(review): presumably the std-normalization scale for DC-AE f32c32 latents
# (name suggests a Deep-Compression-Autoencoder, f=32, c=32) — TODO confirm.
DCAE_F32C32_SCALE = 0.41407
def sample_timesteps_logit_normal(batch_size, device, mean=0.0, std=1.0):
    """Draw timesteps in (0, 1) from a logit-normal distribution.

    A Gaussian sample is squashed through a sigmoid, concentrating mass
    near t=0.5; the clamp keeps values strictly inside (0, 1).

    Args:
        batch_size: number of timesteps to draw.
        device: torch device for the returned tensor.
        mean: mean of the underlying Gaussian.
        std: standard deviation of the underlying Gaussian.

    Returns:
        1-D tensor of shape (batch_size,) with values in [1e-5, 1 - 1e-5].
    """
    logits = torch.normal(mean=mean, std=std, size=(batch_size,), device=device)
    timesteps = torch.sigmoid(logits)
    return timesteps.clamp(1e-5, 1.0 - 1e-5)
def sample_timesteps_uniform(batch_size, device):
    """Draw timesteps uniformly from (0, 1), clamped away from the endpoints.

    Args:
        batch_size: number of timesteps to draw.
        device: torch device for the returned tensor.

    Returns:
        1-D tensor of shape (batch_size,) with values in [1e-5, 1 - 1e-5].
    """
    timesteps = torch.rand(batch_size, device=device)
    return timesteps.clamp(1e-5, 1.0 - 1e-5)
def rectified_flow_forward(z_0, t, noise=None):
    """Build the rectified-flow training pair for clean latents ``z_0``.

    The noisy sample is the straight-line interpolation
    ``z_t = t * noise + (1 - t) * z_0`` and the regression target is the
    constant velocity ``noise - z_0`` along that line.

    Args:
        z_0: clean latents of shape (B, C, H, W).
        t: per-sample timesteps of shape (B,).
        noise: optional Gaussian noise matching ``z_0``; sampled when None.

    Returns:
        Tuple ``(z_t, target)`` of tensors shaped like ``z_0``.
    """
    if noise is None:
        noise = torch.randn_like(z_0)
    # Broadcast the per-sample timestep over the latent dimensions.
    weight = t.view(-1, 1, 1, 1)
    z_t = weight * noise + (1.0 - weight) * z_0
    velocity = noise - z_0
    return z_t, velocity
def flow_matching_loss(model, z_0, context, num_iterations=4, timestep_sampling="logit_normal", scale_factor=DCAE_F32C32_SCALE):
    """Compute the rectified-flow MSE loss for one training batch.

    Latents are scaled, a timestep per sample is drawn, the noisy
    interpolant is formed, and the model's predicted velocity is
    regressed against the true velocity ``noise - z_0``.

    Args:
        model: callable ``model(z_t, t, context, num_iterations=...)``
            returning a velocity prediction shaped like ``z_t``.
        z_0: clean latents of shape (B, C, H, W).
        context: conditioning passed through to the model.
        num_iterations: forwarded to the model.
        timestep_sampling: "logit_normal" for logit-normal timesteps;
            anything else falls back to uniform sampling.
        scale_factor: multiplier applied to ``z_0`` before noising.

    Returns:
        Dict with "loss" (differentiable) and "flow_loss" (detached copy).
    """
    batch = z_0.shape[0]
    dev = z_0.device
    latents = z_0 * scale_factor
    if timestep_sampling == "logit_normal":
        t = sample_timesteps_logit_normal(batch, dev)
    else:
        t = sample_timesteps_uniform(batch, dev)
    noise = torch.randn_like(latents)
    z_t, target = rectified_flow_forward(latents, t, noise)
    predicted = model(z_t, t, context, num_iterations=num_iterations)
    loss = F.mse_loss(predicted, target)
    return {"loss": loss, "flow_loss": loss.detach()}
@torch.no_grad()
def euler_sample(model, noise, context, num_steps=20, num_iterations=4, cfg_scale=1.0, scale_factor=DCAE_F32C32_SCALE):
    """Generate latents by Euler-integrating the velocity field from t=1 to t=0.

    Starts from pure noise at t=1 and takes ``num_steps`` fixed-size Euler
    steps toward t=0. When ``cfg_scale > 1`` each step runs a doubled batch
    (conditional + zeroed context) and applies classifier-free guidance.

    Args:
        model: callable ``model(z, t, context, num_iterations=...)``
            returning a velocity shaped like ``z``.
        noise: initial Gaussian latents of shape (B, C, H, W).
        context: conditioning tensor; zeroed copy serves as the
            unconditional branch for guidance.
        num_steps: number of Euler integration steps.
        num_iterations: forwarded to the model.
        cfg_scale: guidance strength; <= 1 disables guidance.
        scale_factor: latent scale removed from the final sample.

    Returns:
        Denoised latents divided by ``scale_factor``.
    """
    step = -1.0 / num_steps
    z = noise.clone()
    batch = noise.shape[0]
    for idx in range(num_steps):
        current_t = 1.0 - idx / num_steps
        t = torch.full((batch,), current_t, device=noise.device, dtype=noise.dtype)
        if cfg_scale > 1.0:
            # Classifier-free guidance: one forward pass over the
            # concatenated conditional and unconditional batches.
            v_both = model(
                torch.cat([z, z], dim=0),
                torch.cat([t, t], dim=0),
                torch.cat([context, torch.zeros_like(context)], dim=0),
                num_iterations=num_iterations,
            )
            v_cond, v_uncond = v_both.chunk(2, dim=0)
            velocity = v_uncond + cfg_scale * (v_cond - v_uncond)
        else:
            velocity = model(z, t, context, num_iterations=num_iterations)
        z = z + velocity * step
    return z / scale_factor