# LuminaRS / luminars/sampler.py
# asdf98's picture
# Upload luminars/sampler.py
# 0c9146c verified
"""Euler ODE sampler for flow matching inference."""
import torch
@torch.no_grad()
def sample_flow(model, text_emb, shape, n_steps=12, device="cuda"):
    """Integrate a learned velocity field with fixed-step Euler updates.

    Starts from standard-normal noise and takes ``n_steps`` equal steps of
    size ``dt = 1 / n_steps``, visiting ``t = 1, 1 - dt, ..., dt``.  At each
    step the model is queried for a velocity ``v`` and the state is updated
    as ``z <- z - dt * v``.  The model is never evaluated at ``t = 0``.

    Args:
        model: Callable invoked as ``model(z, text_emb, t)``; its result is
            used as the velocity and must broadcast against ``z``.
        text_emb: Conditioning value forwarded unchanged to ``model``.
            NOTE(review): it is not moved to ``device`` here — presumably
            the caller places it on the right device; confirm.
        shape: Shape of the sample tensor; its first dimension is used as
            the batch size for the per-sample time vector.
        n_steps: Number of Euler steps; must be at least 1.
        device: Device on which the noise and the time vector are created.

    Returns:
        The tensor ``z`` after the final Euler step, with shape ``shape``.

    Raises:
        ValueError: If ``n_steps`` is smaller than 1.
    """
    # Guard explicitly: n_steps == 0 would fail with a bare ZeroDivisionError
    # and a negative value would skip the loop and return raw noise.
    if n_steps < 1:
        raise ValueError(f"n_steps must be a positive integer, got {n_steps!r}")

    z = torch.randn(shape, device=device)
    batch_size = z.shape[0]
    dt = 1.0 / n_steps

    for step in range(n_steps):
        # Time runs from 1 down towards 0; the current time is broadcast to
        # one entry per batch element.
        t_val = 1.0 - step * dt
        t = torch.full((batch_size,), t_val, device=device)
        v = model(z, text_emb, t)
        # Explicit Euler step in the direction of decreasing t.
        z = z - dt * v

    return z