asdf98 committed on
Commit
326c29a
·
verified ·
1 Parent(s): f742b0a

Upload iris/flow_matching.py

Browse files
Files changed (1) hide show
  1. iris/flow_matching.py +58 -0
iris/flow_matching.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Flow Matching training utilities for IRIS."""
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import math
6
+ from typing import Optional
7
+
8
+ DCAE_F32C32_SCALE = 0.41407
9
+
10
+
11
def sample_timesteps_logit_normal(batch_size, device, mean=0.0, std=1.0):
    """Draw flow-matching timesteps from a logit-normal distribution.

    A Gaussian sample is squashed through a sigmoid, which concentrates
    probability mass toward the middle of (0, 1); the result is clamped
    strictly inside the open interval to avoid degenerate endpoints.
    """
    gaussian = torch.normal(mean=mean, std=std, size=(batch_size,), device=device)
    timesteps = torch.sigmoid(gaussian)
    return timesteps.clamp(min=1e-5, max=1.0 - 1e-5)
14
+
15
+
16
def sample_timesteps_uniform(batch_size, device):
    """Draw flow-matching timesteps uniformly on (0, 1), clamped away from the endpoints."""
    u = torch.rand(batch_size, device=device)
    return u.clamp(min=1e-5, max=1.0 - 1e-5)
18
+
19
+
20
def rectified_flow_forward(z_0, t, noise=None):
    """Build the rectified-flow training pair (noised latents, velocity target).

    Linearly interpolates between clean latents and Gaussian noise,
    ``z_t = t * noise + (1 - t) * z_0``, and returns the constant velocity
    target ``noise - z_0`` that a rectified-flow model regresses against.

    Args:
        z_0: Clean latents of shape ``(B, ...)``; any rank >= 1.
        t: Per-sample timesteps in (0, 1), shape ``(B,)``.
        noise: Optional pre-drawn noise with the same shape as ``z_0``;
            sampled from a standard normal when omitted.

    Returns:
        Tuple ``(z_t, target)``, both shaped like ``z_0``.
    """
    if noise is None:
        noise = torch.randn_like(z_0)
    # Reshape t to (B, 1, 1, ...) so it broadcasts over every non-batch dim.
    # Generalized from the previous hard-coded 4D view(-1, 1, 1, 1) so that
    # 2D/3D/5D latents work too; behavior for 4D inputs is unchanged.
    t_expand = t.view(-1, *([1] * (z_0.dim() - 1)))
    z_t = t_expand * noise + (1.0 - t_expand) * z_0
    target = noise - z_0
    return z_t, target
27
+
28
+
29
def flow_matching_loss(model, z_0, context, num_iterations=4, timestep_sampling="logit_normal", scale_factor=DCAE_F32C32_SCALE):
    """Compute the rectified-flow matching loss for one training batch.

    Scales the latents, samples per-example timesteps, noises the latents
    along the rectified-flow path, and regresses the model's predicted
    velocity against the ``noise - z_0`` target with MSE.

    Args:
        model: Callable ``model(z_t, t, context, num_iterations=...)``
            returning a velocity prediction shaped like ``z_t``.
        z_0: Clean latents, shape ``(B, ...)``.
        context: Conditioning passed through to the model.
        num_iterations: Forwarded to the model unchanged.
        timestep_sampling: ``"logit_normal"`` or ``"uniform"``.
        scale_factor: Latent scale applied before noising (DC-AE f32c32
            default).

    Returns:
        Dict with ``"loss"`` (differentiable) and ``"flow_loss"`` (detached
        copy for logging, so metrics don't retain the graph).

    Raises:
        ValueError: If ``timestep_sampling`` is not a recognized mode.
    """
    batch_size = z_0.shape[0]
    device = z_0.device
    # Scale latents to roughly unit variance before noising.
    z_0_scaled = z_0 * scale_factor
    if timestep_sampling == "logit_normal":
        t = sample_timesteps_logit_normal(batch_size, device)
    elif timestep_sampling == "uniform":
        t = sample_timesteps_uniform(batch_size, device)
    else:
        # Fail loudly instead of silently falling back to uniform sampling,
        # which previously masked typos in the config.
        raise ValueError(f"Unknown timestep_sampling: {timestep_sampling!r}")
    noise = torch.randn_like(z_0_scaled)
    z_t, target = rectified_flow_forward(z_0_scaled, t, noise)
    v_pred = model(z_t, t, context, num_iterations=num_iterations)
    flow_loss = F.mse_loss(v_pred, target)
    return {"loss": flow_loss, "flow_loss": flow_loss.detach()}
39
+
40
+
41
@torch.no_grad()
def euler_sample(model, noise, context, num_steps=20, num_iterations=4, cfg_scale=1.0, scale_factor=DCAE_F32C32_SCALE):
    """Integrate the learned flow from noise (t=1) down to data (t=0) with Euler steps.

    When ``cfg_scale > 1.0``, classifier-free guidance is applied by running
    conditional and unconditional (zeroed-context) inputs through the model
    in one doubled batch and extrapolating between the two predictions.
    The final latents are divided by ``scale_factor`` to return them to the
    decoder's native range.
    """
    step_size = -1.0 / num_steps
    latents = noise.clone()
    batch = noise.shape[0]
    for step in range(num_steps):
        t_scalar = 1.0 - step / num_steps
        t_batch = torch.full((batch,), t_scalar, device=noise.device, dtype=noise.dtype)
        if cfg_scale > 1.0:
            # Single doubled forward pass: [conditional | unconditional].
            velocity = model(
                torch.cat([latents, latents], dim=0),
                torch.cat([t_batch, t_batch], dim=0),
                torch.cat([context, torch.zeros_like(context)], dim=0),
                num_iterations=num_iterations,
            )
            cond, uncond = velocity.chunk(2, dim=0)
            velocity = uncond + cfg_scale * (cond - uncond)
        else:
            velocity = model(latents, t_batch, context, num_iterations=num_iterations)
        latents = latents + velocity * step_size
    return latents / scale_factor