"""Euler ODE sampler for flow matching inference.""" import torch @torch.no_grad() def sample_flow(model, text_emb, shape, n_steps=12, device="cuda"): z = torch.randn(shape, device=device) B = z.shape[0] dt = 1.0 / n_steps for i in range(n_steps): t_val = 1.0 - i * dt t = torch.full((B,), t_val, device=device) v = model(z, text_emb, t) z = z - dt * v return z