| """Euler ODE sampler for flow matching inference.""" | |
| import torch | |
def sample_flow(model, text_emb, shape, n_steps=12, device="cuda", *, generator=None):
    """Draw samples by integrating a learned flow from noise with explicit Euler.

    Starts from standard-normal noise at t = 1 and takes ``n_steps`` uniform
    Euler steps toward t = 0, updating ``z <- z - dt * v`` where
    ``v = model(z, text_emb, t)``.

    Args:
        model: Callable invoked as ``model(z, text_emb, t)``. Its output must
            broadcast against ``z`` (presumably same shape as ``z`` -- confirm
            against the model).
        text_emb: Conditioning object passed through to ``model`` unchanged.
        shape: Shape of the initial noise tensor. Its first dimension is
            used as the batch size for the timestep tensor.
        n_steps: Number of Euler steps; must be >= 1.
        device: Device for the noise and timestep tensors.
        generator: Optional ``torch.Generator`` for reproducible noise. When
            given, it must live on ``device``.

    Returns:
        The tensor ``z`` after the last step, computed without autograd
        tracking.

    Raises:
        ValueError: If ``n_steps`` is less than 1.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    dt = 1.0 / n_steps
    # Sampling is inference-only: without no_grad, autograd would keep every
    # step's activations alive whenever the model's parameters require grad.
    with torch.no_grad():
        z = torch.randn(shape, device=device, generator=generator)
        batch = z.shape[0]
        for i in range(n_steps):
            # Left-endpoint Euler: evaluate at t_i = 1 - i*dt, i.e. 1, 1-dt, ..., dt.
            t = torch.full((batch,), 1.0 - i * dt, device=device)
            v = model(z, text_emb, t)
            # Step backward in time. Assumes the model predicts dz/dt with
            # t = 1 being noise (TODO confirm sign convention against training).
            z = z - dt * v
    return z