Upload luminars/sampler.py
Browse files- luminars/sampler.py +14 -0
luminars/sampler.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Euler ODE sampler for flow matching inference."""
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
@torch.no_grad()
|
| 5 |
+
def sample_flow(model, text_emb, shape, n_steps=12, device="cuda"):
|
| 6 |
+
z = torch.randn(shape, device=device)
|
| 7 |
+
B = z.shape[0]
|
| 8 |
+
dt = 1.0 / n_steps
|
| 9 |
+
for i in range(n_steps):
|
| 10 |
+
t_val = 1.0 - i * dt
|
| 11 |
+
t = torch.full((B,), t_val, device=device)
|
| 12 |
+
v = model(z, text_emb, t)
|
| 13 |
+
z = z - dt * v
|
| 14 |
+
return z
|