asdf98 commited on
Commit
0c9146c
·
verified ·
1 Parent(s): 785614a

Upload luminars/sampler.py

Browse files
Files changed (1) hide show
  1. luminars/sampler.py +14 -0
luminars/sampler.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Euler ODE sampler for flow matching inference."""
2
+ import torch
3
+
4
+ @torch.no_grad()
5
+ def sample_flow(model, text_emb, shape, n_steps=12, device="cuda"):
6
+ z = torch.randn(shape, device=device)
7
+ B = z.shape[0]
8
+ dt = 1.0 / n_steps
9
+ for i in range(n_steps):
10
+ t_val = 1.0 - i * dt
11
+ t = torch.full((B,), t_val, device=device)
12
+ v = model(z, text_emb, t)
13
+ z = z - dt * v
14
+ return z