File size: 415 Bytes
0c9146c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | """Euler ODE sampler for flow matching inference."""
import torch
@torch.no_grad()
def sample_flow(model, text_emb, shape, n_steps=12, device="cuda", generator=None):
    """Integrate a flow-matching ODE from t = 1 to t = 0 with explicit Euler.

    Starts from standard Gaussian noise at t = 1 and takes ``n_steps`` equal
    steps toward t = 0. Each step applies ``z <- z - dt * v``, where
    ``v = model(z, text_emb, t)``.

    Args:
        model: Callable invoked as ``model(z, text_emb, t)``. It must return
            a tensor broadcastable to ``z``. Presumably this is a velocity
            prediction dz/dt under a convention where t = 1 is noise and
            t = 0 is data. TODO: confirm against the training objective.
        text_emb: Conditioning passed to ``model`` unchanged on every step.
        shape: Shape of the sample to draw. ``shape[0]`` is the batch size
            used for the per-sample time tensor.
        n_steps: Number of Euler steps. Must be >= 1.
        device: Device on which the noise and time tensors are created.
        generator: Optional ``torch.Generator`` (on ``device``) for
            reproducible initial noise. ``None`` uses the global RNG, which
            matches the previous behaviour.

    Returns:
        The tensor ``z`` after integrating down to t = 0.

    Raises:
        ValueError: If ``n_steps`` is less than 1.
    """
    if n_steps < 1:
        # Without this check, 0 raised ZeroDivisionError and negative values
        # silently returned the initial noise without calling the model.
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    z = torch.randn(shape, device=device, generator=generator)
    batch = z.shape[0]
    dt = 1.0 / n_steps
    for i in range(n_steps):
        # t is derived from the step index rather than accumulated, so it
        # does not drift. It visits 1, 1 - dt, ..., dt, and the final update
        # moves z to t = 0.
        t = torch.full((batch,), 1.0 - i * dt, device=device, dtype=z.dtype)
        v = model(z, text_emb, t)
        z = z - dt * v
    return z
|