File size: 2,269 Bytes
326c29a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Flow Matching training utilities for IRIS."""

import torch
import torch.nn.functional as F
import math
from typing import Optional

# Latent scale factor: flow_matching_loss multiplies clean latents by it before
# noising, and euler_sample divides the sampled latents by it on return.
# NOTE(review): the name suggests this is the DC-AE f32c32 autoencoder's latent
# scaling constant — confirm against the autoencoder that produced the latents.
DCAE_F32C32_SCALE: float = 0.41407


def sample_timesteps_logit_normal(batch_size, device, mean=0.0, std=1.0):
    """Draw ``batch_size`` timesteps from a logit-normal distribution.

    A Gaussian sample ``u ~ N(mean, std**2)`` is squashed through a sigmoid and
    then clamped to ``[1e-5, 1 - 1e-5]`` so no timestep lands exactly on 0 or 1.

    Args:
        batch_size: Number of timesteps to draw.
        device: Device on which the returned tensor is created.
        mean: Mean of the underlying Gaussian (in logit space).
        std: Standard deviation of the underlying Gaussian.

    Returns:
        1-D tensor of shape ``(batch_size,)``.
    """
    eps = 1e-5
    logits = torch.normal(mean=mean, std=std, size=(batch_size,), device=device)
    timesteps = torch.sigmoid(logits)
    return timesteps.clamp(eps, 1.0 - eps)


def sample_timesteps_uniform(batch_size, device):
    """Draw ``batch_size`` timesteps uniformly from ``[0, 1)``.

    The draws are clamped to ``[1e-5, 1 - 1e-5]`` so no timestep lands exactly
    on 0 or 1.

    Args:
        batch_size: Number of timesteps to draw.
        device: Device on which the returned tensor is created.

    Returns:
        1-D tensor of shape ``(batch_size,)``.
    """
    eps = 1e-5
    draws = torch.rand(batch_size, device=device)
    return draws.clamp(eps, 1.0 - eps)


def rectified_flow_forward(z_0, t, noise=None):
    """Noise clean data along the rectified-flow (linear interpolation) path.

    Computes ``z_t = t * noise + (1 - t) * z_0`` together with the velocity
    target ``noise - z_0``, which is the derivative of ``z_t`` with respect to
    ``t``.

    Args:
        z_0: Clean data tensor whose first dimension is the batch. Any rank is
            supported; ``t`` is broadcast over all non-batch dimensions.
        t: Per-sample timesteps, one per batch element (shape ``(B,)``). A 0-d
            tensor or Python float is also accepted and applied to every sample.
        noise: Optional noise tensor shaped like ``z_0``. It is drawn with
            ``torch.randn_like(z_0)`` when omitted.

    Returns:
        Tuple ``(z_t, target)``, both broadcast to ``z_0``'s shape.
    """
    if noise is None:
        noise = torch.randn_like(z_0)
    if not torch.is_tensor(t):
        t = torch.as_tensor(t, dtype=z_0.dtype, device=z_0.device)
    # Expand t to (B, 1, ..., 1) matching z_0's rank. A fixed view(-1, 1, 1, 1)
    # only fits 4-D inputs: for 3-D latents it would silently broadcast to
    # (B, B, ...) instead of raising. reshape also tolerates non-contiguous t.
    t_expand = t.reshape(-1, *([1] * (z_0.dim() - 1)))
    z_t = t_expand * noise + (1.0 - t_expand) * z_0
    target = noise - z_0
    return z_t, target


def flow_matching_loss(
    model,
    z_0,
    context,
    num_iterations=4,
    timestep_sampling="logit_normal",
    scale_factor=DCAE_F32C32_SCALE,
    cfg_dropout_prob=0.0,
):
    """Compute the rectified-flow matching (MSE) loss for one batch.

    The clean latents are multiplied by ``scale_factor``, noised to a randomly
    sampled timestep along the linear path, and the model's prediction is
    regressed onto the velocity target ``noise - z_0_scaled`` with
    ``F.mse_loss`` (mean reduction).

    Args:
        model: Callable invoked as ``model(z_t, t, context, num_iterations=...)``
            and expected to return a velocity prediction shaped like ``z_t``.
        z_0: Clean latents; first dimension is the batch.
        context: Conditioning passed through to ``model``. Its first dimension
            is assumed to be the batch; this is only relied upon when
            ``cfg_dropout_prob > 0``.
        num_iterations: Forwarded unchanged to ``model``.
        timestep_sampling: ``"logit_normal"`` or ``"uniform"``.
        scale_factor: Multiplier applied to ``z_0`` before noising.
        cfg_dropout_prob: Probability of replacing each sample's ``context``
            with zeros. This matches the all-zeros unconditional context that
            ``euler_sample`` uses for classifier-free guidance. The default
            ``0.0`` leaves ``context`` untouched and draws no extra random
            numbers.

    Returns:
        Dict with ``"loss"`` (the differentiable scalar loss) and
        ``"flow_loss"`` (the same value, detached).

    Raises:
        ValueError: If ``timestep_sampling`` is not a recognised scheme or
            ``cfg_dropout_prob`` is outside ``[0, 1]``.
    """
    # Reject unknown schemes explicitly rather than silently falling back to
    # uniform sampling (e.g. on a typo such as "logit-normal").
    if timestep_sampling == "logit_normal":
        sample_t = sample_timesteps_logit_normal
    elif timestep_sampling == "uniform":
        sample_t = sample_timesteps_uniform
    else:
        raise ValueError(
            f"timestep_sampling must be 'logit_normal' or 'uniform', got {timestep_sampling!r}"
        )
    if not 0.0 <= cfg_dropout_prob <= 1.0:
        raise ValueError(f"cfg_dropout_prob must be in [0, 1], got {cfg_dropout_prob}")

    B = z_0.shape[0]
    device = z_0.device
    z_0_scaled = z_0 * scale_factor
    t = sample_t(B, device)
    noise = torch.randn_like(z_0_scaled)
    z_t, target = rectified_flow_forward(z_0_scaled, t, noise)

    if cfg_dropout_prob > 0.0:
        # Per-sample keep mask over the batch (dim 0), broadcast across the
        # remaining context dimensions; dropped samples get all-zeros context.
        keep = torch.rand(B, device=device) >= cfg_dropout_prob
        keep = keep.reshape(-1, *([1] * (context.dim() - 1)))
        context = torch.where(keep, context, torch.zeros_like(context))

    v_pred = model(z_t, t, context, num_iterations=num_iterations)
    flow_loss = F.mse_loss(v_pred, target)
    return {"loss": flow_loss, "flow_loss": flow_loss.detach()}


@torch.no_grad()
def euler_sample(model, noise, context, num_steps=20, num_iterations=4, cfg_scale=1.0, scale_factor=DCAE_F32C32_SCALE):
    """Generate latents by Euler-integrating the predicted velocity from t=1 to t=0.

    Starting from ``noise`` at t=1, takes ``num_steps`` equal steps of size
    ``1 / num_steps`` toward t=0. The model is evaluated at
    ``t = 1, 1 - 1/num_steps, ..., 1/num_steps``. The final state is divided by
    ``scale_factor`` to undo the latent scaling applied in
    ``flow_matching_loss``. Runs under ``torch.no_grad()``.

    Args:
        model: Callable invoked as ``model(z, t, context, num_iterations=...)``
            and expected to return a velocity prediction shaped like ``z``.
        noise: Starting sample at t=1; first dimension is the batch. It is
            cloned, not modified in place.
        context: Conditioning for ``model``. When guidance is active it is
            concatenated along dim 0 with an all-zeros copy, so its first
            dimension is assumed to be the batch.
        num_steps: Number of Euler steps; must be >= 1.
        num_iterations: Forwarded unchanged to ``model``.
        cfg_scale: Classifier-free guidance weight. Guidance is applied only
            when ``cfg_scale > 1.0``, using all-zeros context as the
            unconditional input.
        scale_factor: Divisor applied to the final latents.

    Returns:
        The integrated latents divided by ``scale_factor`` (shaped like
        ``noise`` when the model returns same-shaped velocities).

    Raises:
        ValueError: If ``num_steps`` is less than 1.
    """
    if num_steps < 1:
        # 0 would divide by zero below; negative values would skip the loop
        # and silently return the (rescaled) input noise.
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    use_cfg = cfg_scale > 1.0
    if use_cfg:
        # The [conditional, unconditional] context pair is the same at every
        # step, so build it once instead of re-allocating it per iteration.
        ctx_double = torch.cat([context, torch.zeros_like(context)], dim=0)

    dt = -1.0 / num_steps  # negative: integrating backward from t=1 to t=0
    z_t = noise.clone()
    for i in range(num_steps):
        t_val = 1.0 - i / num_steps
        t = torch.full((noise.shape[0],), t_val, device=noise.device, dtype=noise.dtype)
        if use_cfg:
            # Run conditional and unconditional passes as one doubled batch.
            z_double = torch.cat([z_t, z_t], dim=0)
            t_double = torch.cat([t, t], dim=0)
            v_pred = model(z_double, t_double, ctx_double, num_iterations=num_iterations)
            v_cond, v_uncond = v_pred.chunk(2, dim=0)
            v_pred = v_uncond + cfg_scale * (v_cond - v_uncond)
        else:
            v_pred = model(z_t, t, context, num_iterations=num_iterations)
        z_t = z_t + v_pred * dt
    return z_t / scale_factor