File size: 415 Bytes
0c9146c
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""Euler ODE sampler for flow matching inference."""
import torch

@torch.no_grad()
def sample_flow(model, text_emb, shape, n_steps=12, device="cuda", *, generator=None):
    """Draw samples by integrating a learned velocity field with explicit Euler.

    Integration starts from standard Gaussian noise at ``t = 1`` and takes
    ``n_steps`` uniform steps of size ``dt = 1 / n_steps`` toward ``t = 0``.
    At each step it applies ``z <- z - dt * v``, where
    ``v = model(z, text_emb, t)``. Gradients are disabled for the whole call.

    Args:
        model: Callable invoked as ``model(z, text_emb, t)``. ``t`` is a 1-D
            tensor of length ``shape[0]`` filled with the current time. The
            return value is subtracted (scaled by ``dt``) from ``z``, so it is
            presumably a velocity with ``z``'s shape.
        text_emb: Conditioning forwarded to ``model`` unchanged.
        shape: Shape of the initial noise tensor. ``shape[0]`` is used as the
            batch size for the time tensor.
        n_steps: Number of Euler steps; must be at least 1.
        device: Device on which the noise and time tensors are created.
        generator: Optional ``torch.Generator`` used to draw the initial noise
            reproducibly. It must live on ``device``.

    Returns:
        The tensor ``z`` after the final step (nominally at ``t = 0``).

    Raises:
        ValueError: If ``n_steps`` is less than 1.
    """
    if n_steps < 1:
        # n_steps == 0 would divide by zero below, and a negative value would
        # skip the loop entirely and silently return un-denoised noise.
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")
    z = torch.randn(shape, device=device, generator=generator)
    batch = z.shape[0]
    dt = 1.0 / n_steps
    for i in range(n_steps):
        # Time at the start of step i: 1, 1 - dt, ..., dt (descending to 0).
        t = torch.full((batch,), 1.0 - i * dt, device=device)
        v = model(z, text_emb, t)
        z = z - dt * v
    return z