Upload iris/flow_matching.py
iris/flow_matching.py +58 -0
iris/flow_matching.py
ADDED
@@ -0,0 +1,58 @@
"""Flow Matching training utilities for IRIS."""

import torch
import torch.nn.functional as F
import math
from typing import Optional

# Normalization constant for DC-AE f32c32 latents: latents are multiplied by
# this scale before flow matching and divided by it after sampling.
DCAE_F32C32_SCALE = 0.41407


def sample_timesteps_logit_normal(batch_size, device, mean=0.0, std=1.0):
    """Sample timesteps in (0, 1) from a logit-normal distribution."""
    u = torch.normal(mean=mean, std=std, size=(batch_size,), device=device)
    return torch.sigmoid(u).clamp(1e-5, 1.0 - 1e-5)


def sample_timesteps_uniform(batch_size, device):
    """Sample timesteps uniformly in (0, 1)."""
    return torch.rand(batch_size, device=device).clamp(1e-5, 1.0 - 1e-5)


def rectified_flow_forward(z_0, t, noise=None):
    """Interpolate linearly between data z_0 (t=0) and noise (t=1).

    Returns the interpolated latent z_t and the velocity target (noise - z_0).
    """
    if noise is None:
        noise = torch.randn_like(z_0)
    t_expand = t.view(-1, 1, 1, 1)
    z_t = t_expand * noise + (1.0 - t_expand) * z_0
    target = noise - z_0
    return z_t, target


def flow_matching_loss(model, z_0, context, num_iterations=4,
                       timestep_sampling="logit_normal",
                       scale_factor=DCAE_F32C32_SCALE):
    """Rectified-flow velocity-prediction MSE loss on scaled latents."""
    B = z_0.shape[0]
    device = z_0.device
    z_0_scaled = z_0 * scale_factor
    t = (sample_timesteps_logit_normal(B, device)
         if timestep_sampling == "logit_normal"
         else sample_timesteps_uniform(B, device))
    noise = torch.randn_like(z_0_scaled)
    z_t, target = rectified_flow_forward(z_0_scaled, t, noise)
    v_pred = model(z_t, t, context, num_iterations=num_iterations)
    flow_loss = F.mse_loss(v_pred, target)
    return {"loss": flow_loss, "flow_loss": flow_loss.detach()}


@torch.no_grad()
def euler_sample(model, noise, context, num_steps=20, num_iterations=4,
                 cfg_scale=1.0, scale_factor=DCAE_F32C32_SCALE):
    """Integrate the learned velocity field from t=1 (noise) to t=0 (data)
    with Euler steps, optionally applying classifier-free guidance."""
    dt = -1.0 / num_steps
    z_t = noise.clone()
    for i in range(num_steps):
        t_val = 1.0 - i / num_steps
        t = torch.full((noise.shape[0],), t_val, device=noise.device, dtype=noise.dtype)
        if cfg_scale > 1.0:
            # Classifier-free guidance: run the conditional and unconditional
            # (zeroed-context) branches in one batched forward pass.
            z_double = torch.cat([z_t, z_t], dim=0)
            t_double = torch.cat([t, t], dim=0)
            ctx_double = torch.cat([context, torch.zeros_like(context)], dim=0)
            v_pred = model(z_double, t_double, ctx_double, num_iterations=num_iterations)
            v_cond, v_uncond = v_pred.chunk(2, dim=0)
            v_pred = v_uncond + cfg_scale * (v_cond - v_uncond)
        else:
            v_pred = model(z_t, t, context, num_iterations=num_iterations)
        z_t = z_t + v_pred * dt
    return z_t / scale_factor
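
For reference, a minimal usage sketch of these utilities. The stand-in network, the latent shape [B, 32, H, W], and the context shape [B, N, D] are assumptions made for illustration; only the call signature model(z_t, t, context, num_iterations=...) is implied by this file.

import torch
from iris.flow_matching import flow_matching_loss, euler_sample

class TinyVelocityNet(torch.nn.Module):
    # Hypothetical stand-in for the real IRIS model; it only matches the
    # assumed call signature and ignores the conditioning context.
    def __init__(self, channels=32):
        super().__init__()
        self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, z_t, t, context, num_iterations=4):
        return self.conv(z_t)

model = TinyVelocityNet()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# One training step: regress the velocity (noise - z_0) at a random timestep.
z_0 = torch.randn(8, 32, 8, 8)       # clean latents (placeholder data)
context = torch.randn(8, 77, 768)    # conditioning tokens (placeholder)
out = flow_matching_loss(model, z_0, context, num_iterations=4)
out["loss"].backward()
optimizer.step()
optimizer.zero_grad()

# Sampling: start from pure noise at t=1 and integrate to t=0 with Euler steps;
# euler_sample already divides the result by scale_factor.
noise = torch.randn(4, 32, 8, 8)
latents = euler_sample(model, noise, context[:4], num_steps=20, cfg_scale=3.0)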