Tinman-SmolOmni-MLA-256M / ode_solver.py
TinmanLabSL's picture
ONNX export: ode_solver.py
b4251b3 verified
raw
history blame contribute delete
854 Bytes
"""
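ONNX Runtime ODE solver for SmolOmni flow-matching image generation.

Usage (fixed-step Euler integration of the learned velocity field):

    import numpy as np
    import onnxruntime as ort

    sess_ctx = ort.InferenceSession("smolomni_256M_gen_context.onnx")
    sess_flow = ort.InferenceSession("smolomni_256M_flow_head_step.onnx")

    def generate_image(prompt_tokens, num_steps=50):
        # prompt_tokens must match the context model's "input_ids" input
        # (NOTE: dtype/shape assumed, e.g. int64 of shape (1, seq_len) --
        # confirm against the exported graph).
        # The prompt is encoded once; the same context is reused every step.
        ctx = sess_ctx.run(None, {"input_ids": prompt_tokens})[0]
        # Start from standard Gaussian noise in latent space
        # (NOTE: (1, 4, 32, 32) assumed to match the flow head's input).
        latents = np.random.randn(1, 4, 32, 32).astype(np.float32)
        dt = 1.0 / num_steps
        for i in range(num_steps):
            # Timestep passed to the flow head lies in [0, 1000).
            t = np.array([i * dt * 1000], dtype=np.float32)
            velocity = sess_flow.run(None, {
                "noisy_latents": latents,
                "timestep": t,
                "context": ctx,
            })[0]
            # Explicit Euler step along the predicted velocity.
            latents = latents + velocity * dt
        return latents  # Pass to the VAE decoder to obtain the final image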
"""