# sana-wm / app.py
# (Hugging Face file-viewer header: multimodalart, HF Staff — "Update app.py",
# commit 128f6ae, verified)
"""ZeroGPU Gradio demo for SANA-WM (bidirectional).
Loads ``Efficient-Large-Model/SANA-WM_bidirectional`` (Stage-1 DiT + LTX-2 VAE
+ gemma-2-2b-it text encoder, no refiner) and lets the user roll out a
camera-controlled image-to-video clip from a WASD/IJKL action queue.
"""
from __future__ import annotations
# ``spaces`` must be imported BEFORE anything that initialises CUDA (including
# torch, mmcv, etc.). mmcv pulls torch in transitively, so even our runtime
# install path below would touch CUDA emulation first if ``spaces`` weren't
# already imported.
import spaces # noqa: E402,F401
import os
import sys
import time
import tempfile
from pathlib import Path
# The feat/sana-wm branch of Sana is bundled next to this file as ./Sana when
# deploying to a Space; put that directory at the front of the import path.
ROOT = Path(__file__).resolve().parent
SANA_DIR = ROOT / "Sana"
if not any(entry == str(SANA_DIR) for entry in sys.path):
    sys.path.insert(0, str(SANA_DIR))
# Must be set before any Sana / xformers import (see inference_sana_wm.py).
# An already-exported value is left untouched.
if "DISABLE_XFORMERS" not in os.environ:
    os.environ["DISABLE_XFORMERS"] = "1"
# mmcv 1.7.2's setup.py imports ``pkg_resources`` (removed in setuptools 80+),
# so it cannot be installed via Spaces' isolated pip build env. We follow the
# Sana ``environment_setup.sh`` recipe and install it here on first cold-start:
# pin setuptools<80, then build mmcv with ``--no-build-isolation`` so the
# build sees the env's setuptools. Subsequent imports are no-ops.
def _ensure_mmcv() -> None:
    """Make ``mmcv`` importable, installing it with pip if it is missing.

    If ``import mmcv`` already works this returns immediately. Otherwise it
    pins ``setuptools<80`` (plus ``wheel``) and then builds ``mmcv==1.7.2``
    with ``--no-build-isolation`` so the build sees that setuptools.

    Raises:
        subprocess.CalledProcessError: if either pip invocation fails.
        ImportError: if mmcv still cannot be imported after installation.
    """
    try:
        import mmcv  # noqa: F401
        return
    except ImportError:
        pass

    # Local imports: this function runs before the rest of the module's
    # imports have executed.
    import importlib
    import subprocess
    import sys as _sys

    print("[startup] mmcv not found — installing with --no-build-isolation...", flush=True)
    pip_install = [_sys.executable, "-m", "pip", "install", "--quiet"]
    subprocess.check_call([*pip_install, "setuptools<80", "wheel"])
    subprocess.check_call([*pip_install, "--no-build-isolation", "mmcv==1.7.2"])

    # The package appeared after the interpreter started; flush the import
    # system's cached directory listings so the new files are noticed.
    importlib.invalidate_caches()
    import mmcv  # noqa: F401
    print("[startup] mmcv installed.", flush=True)
_ensure_mmcv()
# flash-attn matches Sana's ``environment_setup.sh`` pin (``flash-attn>=2.7.0``).
# ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` skips the nvcc kernel build, so the
# wheel installs in seconds — but ``flash_attn/__init__.py`` then unconditionally
# does ``import flash_attn_2_cuda``, which doesn't exist. We pre-stub that
# CUDA-extension module in ``sys.modules`` so the import succeeds; SANA-WM uses
# Triton GDN paths and never reaches the actual CUDA kernels.
def _ensure_flash_attn() -> None:
    """Make ``flash_attn`` importable without building its CUDA kernels.

    An empty stand-in module named ``flash_attn_2_cuda`` is registered in
    ``sys.modules`` (unless one is already there) so that ``import flash_attn``
    does not fail on the missing compiled extension. If ``flash_attn`` is then
    still not importable, it is pip-installed with
    ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` in the subprocess environment.

    NOTE(review): the stand-in is registered before any attempt to import a
    real ``flash_attn_2_cuda``, so a genuinely installed extension would be
    shadowed as well — confirm this is intended.

    Raises:
        subprocess.CalledProcessError: if the pip invocation fails.
        ImportError: if flash_attn still cannot be imported after installation.
    """
    import types as _types
    import sys as _sys

    if "flash_attn_2_cuda" not in _sys.modules:
        _sys.modules["flash_attn_2_cuda"] = _types.ModuleType("flash_attn_2_cuda")
    try:
        import flash_attn  # noqa: F401
        return
    except ImportError:
        pass

    import importlib
    import subprocess

    print("[startup] flash-attn not found — installing (CUDA build skipped)...", flush=True)
    # Copy the environment so the flag only affects the pip subprocess.
    env = dict(os.environ)
    env["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"
    subprocess.check_call(
        [
            _sys.executable, "-m", "pip", "install", "--quiet",
            "--no-build-isolation", "flash-attn>=2.7.0",
        ],
        env=env,
    )
    # Drop any leftover entry from the failed import, and flush the import
    # system's caches so the freshly installed package is found.
    _sys.modules.pop("flash_attn", None)
    importlib.invalidate_caches()
    import flash_attn  # noqa: F401
    print("[startup] flash-attn installed.", flush=True)
_ensure_flash_attn()
import gradio as gr
import imageio.v3 as iio
import numpy as np
import torch
from PIL import Image
# Sana-WM imports — registered by importing nets, then the pipeline module.
import diffusion.model.nets # noqa: F401
from inference_video_scripts.inference_sana_wm import (
DEFAULT_ROTATION_SPEED_DEG,
DEFAULT_TRANSLATION_SPEED,
GenerationParams,
InferenceConfig,
SanaWMPipeline,
TARGET_HEIGHT,
TARGET_WIDTH,
action_string_to_c2w,
estimate_intrinsics_with_pi3x,
resize_and_center_crop,
transform_intrinsics_for_crop,
)
from diffusion.utils.action_overlay import apply_overlay
from diffusion.utils.logger import get_root_logger
import pyrallis
from sana.tools import resolve_hf_path
# Hugging Face repo holding the checkpoint, and the ``hf://`` URIs of the two
# files this demo passes to ``resolve_hf_path``.
HF_REPO = "Efficient-Large-Model/SANA-WM_bidirectional"
CONFIG_URI = f"hf://{HF_REPO}/config.yaml"
MODEL_URI = f"hf://{HF_REPO}/dit/sana_wm_1600m_720p.safetensors"
# Shared logger, handed to the pipeline and used by ``infer``.
LOGGER = get_root_logger()
# ---------------------------------------------------------------------------
# Pipeline — built at module load time. ZeroGPU's CUDA emulation lets us place
# tensors on "cuda" without a real GPU; @spaces.GPU swaps in a real device for
# the duration of the decorated call. Lazy-loading inside the decorator is
# strongly discouraged on ZeroGPU.
# ---------------------------------------------------------------------------
print("[startup] Building SanaWMPipeline (downloads ~17 GB on first launch)...", flush=True)
_t0 = time.time()  # start of the build, used only for the timing message below
# Parse the repo's config.yaml into an InferenceConfig; ``args=[]`` means no
# command-line arguments are passed to the parser.
PIPELINE_CONFIG: InferenceConfig = pyrallis.parse(
    config_class=InferenceConfig, config_path=resolve_hf_path(CONFIG_URI), args=[]
)
# Stage-1 only: no refiner is attached.
PIPELINE = SanaWMPipeline(
    config=PIPELINE_CONFIG,
    model_path=resolve_hf_path(MODEL_URI),
    device="cuda",
    refiner=None,
    logger=LOGGER,
)
print(f"[startup] Pipeline ready in {time.time() - _t0:.1f}s.", flush=True)
# ---------------------------------------------------------------------------
# Action-queue helpers (queue list <-> DSL string)
# ---------------------------------------------------------------------------
# Key characters kept by ``queue_to_dsl``; any other character in a segment's
# ``keys`` is dropped.
ALLOWED_KEYS = set("wasdijkl")
def _normalize_queue(queue) -> list[dict]:
    """Coerce any supported queue representation into a clean list of segments.

    Accepts a list/tuple of ``{"keys": ..., "frames": ...}`` dicts (from the
    UI), a DSL string such as ``"w-24,jw-8"`` (from examples / serialized
    state), or anything falsy.

    Both input forms are validated the same way: a segment whose frame count
    cannot be converted to ``int`` is skipped, frame counts are clamped to at
    least 1, and ``keys`` is lower-cased. A dict without ``"frames"`` counts
    as one frame. Non-dict list entries and unsupported input types are
    ignored.

    Returns:
        A new list of ``{"keys": str, "frames": int}`` dicts (possibly empty).
    """
    if not queue:
        return []

    # Collect raw (keys, frames) pairs first so one validation path serves
    # both input forms.
    if isinstance(queue, str):
        pairs = []
        for seg in queue.split(","):
            seg = seg.strip()
            if not seg or "-" not in seg:
                continue
            # rsplit: the frame count is whatever follows the LAST dash.
            keys, frames = seg.rsplit("-", 1)
            pairs.append((keys, frames))
    elif isinstance(queue, (list, tuple)):
        pairs = [
            (seg.get("keys"), seg.get("frames", 1))
            for seg in queue
            if isinstance(seg, dict)
        ]
    else:
        return []

    out = []
    for keys, frames in pairs:
        try:
            count = int(frames)
        except (TypeError, ValueError, OverflowError):
            # Unparseable frame count (None, "abc", NaN, inf, ...): drop it.
            continue
        out.append({"keys": str(keys or "").lower(), "frames": max(1, count)})
    return out
def queue_to_dsl(queue) -> str:
    """Render a queue as its DSL string, e.g. ``[{keys: 'wj', frames: 16}]`` -> ``'jw-16'``.

    The input goes through ``_normalize_queue`` first, so a DSL string is
    accepted too. Within each segment only characters in ``ALLOWED_KEYS`` are
    kept, de-duplicated and sorted; a segment with none left is written as
    ``none``. Frame counts are clamped to at least 1. An empty queue gives
    ``""``.
    """
    segments = _normalize_queue(queue)
    if not segments:
        return ""

    def _render(segment) -> str:
        raw_keys = (segment.get("keys") or "").lower().strip()
        kept = {ch for ch in raw_keys if ch in ALLOWED_KEYS}
        token = "".join(sorted(kept)) or "none"
        count = max(1, int(segment.get("frames", 1)))
        return f"{token}-{count}"

    return ",".join(_render(segment) for segment in segments)
def queue_total_frames(queue) -> int:
    """Return the summed frame count of all segments (each counted as >= 1)."""
    total = 0
    for segment in _normalize_queue(queue):
        total += max(1, int(segment.get("frames", 1)))
    return total
def _snap_num_frames(n: int, stride: int = 8, upper_bound: int | None = None) -> int:
    """Snap ``n`` to the nearest value of the form ``stride * k + 1``.

    LTX-2 VAE wants num_frames = 8*k + 1. Ties round up. If the rounded-up
    value would exceed ``upper_bound``, the value just below ``n`` is used
    instead. Anything below 1 snaps to 1.
    """
    if n < 1:
        return 1

    overshoot = (n - 1) % stride
    if overshoot == 0:
        # Already on the grid.
        return n

    below = n - overshoot
    above = below + stride
    # Distances to the two neighbours; equal distances go to the upper one.
    if (n - below) < (above - n):
        chosen = below
    else:
        chosen = above
    if upper_bound is not None and chosen > upper_bound:
        chosen = below
    return max(chosen, 1)
# ---------------------------------------------------------------------------
# tqdm <-> gr.Progress bridge
#
# Sana's samplers do ``from tqdm import tqdm`` at module load, so the symbol
# binds before our patch runs. We swap that *bound name* on the sampler module
# with a subclass that forwards progress on every update, then restore on exit.
# Avoids ``track_tqdm=True``, which breaks inside ZeroGPU subprocesses.
# ---------------------------------------------------------------------------
import contextlib
import importlib
# Module paths whose module-level ``tqdm`` attribute ``_bridge_sana_tqdm``
# swaps out. Entries that fail to import, or that have no ``tqdm`` attribute,
# are skipped there.
_SANA_TQDM_TARGETS = [
    "diffusion.scheduler.flow_euler_sampler",
    "diffusion.scheduler.dpm_solver",
    "diffusion.model.dpm_solver",
]
@contextlib.contextmanager
def _bridge_sana_tqdm(progress: "gr.Progress", *, span=(0.4, 0.9), desc_prefix="DiT"):
    """Context manager that mirrors Sana's tqdm ticks onto the Gradio progress bar.

    While active, the ``tqdm`` attribute of every importable module in
    ``_SANA_TQDM_TARGETS`` is replaced by a subclass whose ``update`` also
    calls ``progress``. The bar's completed fraction is mapped linearly onto
    ``span`` (a ``(low, high)`` pair). The original attributes are restored
    on exit, including when the body raises.

    Args:
        progress: callable invoked as ``progress(value, desc=...)``.
        span: range of the Gradio bar covered by the bridged loop.
        desc_prefix: label placed before ``step i/N`` in the description.
    """
    import tqdm as _tqdm_mod

    lo, hi = span

    class _BridgedTqdm(_tqdm_mod.tqdm):
        def update(self, n=1):
            result = super().update(n)
            try:
                if self.total:
                    done = min(1.0, max(0.0, self.n / self.total))
                    progress(lo + (hi - lo) * done, desc=f"{desc_prefix} step {self.n}/{self.total}")
            except Exception:
                # Best effort: a failing progress callback must not abort sampling.
                pass
            return result

    originals: dict[str, object] = {}
    for target in _SANA_TQDM_TARGETS:
        try:
            module = importlib.import_module(target)
        except Exception:
            continue
        if not hasattr(module, "tqdm"):
            continue
        originals[target] = module.tqdm
        module.tqdm = _BridgedTqdm

    try:
        yield
    finally:
        # Put back whatever each module had bound before the swap.
        for target, original in originals.items():
            try:
                importlib.import_module(target).tqdm = original
            except Exception:
                pass
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
@spaces.GPU(duration=210)
def infer(
    image: Image.Image,
    prompt: str,
    queue_state,
    num_frames: int = 161,
    fps: int = 16,
    steps: int = 40,
    cfg_scale: float = 5.0,
    translation_speed: float = DEFAULT_TRANSLATION_SPEED,
    rotation_speed_deg: float = DEFAULT_ROTATION_SPEED_DEG,
    seed: int = 42,
    show_overlay: bool = True,
    progress: gr.Progress = gr.Progress(),
):
    """Generate a camera-controlled clip from a first frame, a prompt and an action queue.

    The queue is turned into a DSL string and rolled out into camera poses.
    The frame count is capped by the rollout length and snapped with
    ``_snap_num_frames``. The image is resized/cropped, intrinsics are
    estimated on the uploaded image and adjusted for the crop, and the
    pipeline is sampled with the ``flow_euler_ltx`` algorithm. The result is
    written as ``sana_wm.mp4`` in a fresh temporary directory.

    Args:
        image: first frame; ``None`` is rejected.
        prompt: text prompt; blank is rejected.
        queue_state: action queue in any form ``queue_to_dsl`` accepts.
        num_frames: requested frame count before capping and snapping.
        fps, steps, cfg_scale, seed: forwarded to ``GenerationParams``.
        translation_speed, rotation_speed_deg: forwarded to ``action_string_to_c2w``.
        show_overlay: whether to run ``apply_overlay`` on the output video.
        progress: Gradio progress callback.

    Returns:
        ``(path_to_mp4, dsl_string)``.

    Raises:
        gr.Error: on a missing image, blank prompt, empty queue, or fewer than
            9 frames after snapping.
    """
    # ---- validate inputs ---------------------------------------------------
    if image is None:
        raise gr.Error("Please upload a first-frame image.")
    if not (prompt and prompt.strip()):
        raise gr.Error("Please enter a prompt.")
    dsl = queue_to_dsl(queue_state or [])
    if not dsl:
        raise gr.Error("Action queue is empty — add at least one segment.")

    # ---- camera trajectory -------------------------------------------------
    progress(0.05, desc="Rolling out trajectory")
    trajectory = action_string_to_c2w(
        dsl,
        translation_speed=translation_speed,
        rotation_speed_deg=rotation_speed_deg,
    )
    available = trajectory.shape[0]
    requested = min(int(num_frames), available)
    snapped = _snap_num_frames(requested, stride=8, upper_bound=available)
    if snapped < 9:
        raise gr.Error(
            f"Need at least 9 frames after LTX-2 snapping (8k+1); your queue rolled out {available} frames."
        )
    c2w = trajectory[:snapped]

    # ---- image + intrinsics ------------------------------------------------
    progress(0.1, desc="Preparing image")
    cropped, src_size, resized_size, crop_offset = resize_and_center_crop(image)

    progress(0.15, desc="Estimating intrinsics (Pi3X)")
    device = torch.device("cuda")
    single_intrinsics = estimate_intrinsics_with_pi3x(image, device, LOGGER)
    # One row of intrinsics per frame; copy() makes the broadcast view writable.
    per_frame_intrinsics = np.broadcast_to(single_intrinsics, (snapped, 4)).copy()
    intrinsics_vec4 = transform_intrinsics_for_crop(
        per_frame_intrinsics, src_size, resized_size, crop_offset
    )

    # ---- sampling ----------------------------------------------------------
    params = GenerationParams(
        num_frames=snapped,
        fps=int(fps),
        step=int(steps),
        cfg_scale=float(cfg_scale),
        seed=int(seed),
        sampling_algo="flow_euler_ltx",
    )
    progress(0.4, desc=f"Sampling {snapped} frames in {steps} steps")
    started = time.time()
    with _bridge_sana_tqdm(progress, span=(0.4, 0.9), desc_prefix="DiT"):
        result = PIPELINE.generate(cropped, prompt.strip(), c2w, intrinsics_vec4, params)
    LOGGER.info(f"Sampling done in {time.time() - started:.1f}s")

    # ---- post-processing + export -----------------------------------------
    frames = result["video"]
    if show_overlay:
        progress(0.92, desc="Compositing action overlay")
        frames = apply_overlay(frames, result["c2w"])

    out_dir = Path(tempfile.mkdtemp())
    out_path = out_dir / "sana_wm.mp4"
    iio.imwrite(out_path, frames, fps=params.fps)
    return str(out_path), dsl
# ---------------------------------------------------------------------------
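# ActionQueue: custom gr.HTML component
# - 8 directional buttons (WASD = translate, IJKL = rotate) + "idle"
# - Frames-per-tap input
# - Removable queue chips, plus a live view of the resulting DSL string
# - NOTE: no Three.js trajectory preview is implemented in the template / JS
#   below; three.js is only loaded through the ``head`` script given to
#   ``demo.launch``.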
# ---------------------------------------------------------------------------
# Markup for the ActionQueue component: two 3x3 button grids (translate with an
# idle button in the centre, rotate), a "Frames per tap" number input with
# Back / Clear buttons, and the queue display (chips, totals, DSL text).
# The element ids and the ``data-k`` attributes are what JS_ON_LOAD looks up.
# NOTE(review): the number input allows up to 320, while JS_ON_LOAD caps the
# whole queue at 160 frames — the script clamps, so this is cosmetic.
HTML_TEMPLATE = """
<div class="aq-wrap">
<div class="aq-pads">
<div class="aq-pad">
<div class="aq-pad-label">Translate (WASD)</div>
<div class="aq-grid">
<div></div><button class="aq-btn aq-tr" data-k="w">W</button><div></div>
<button class="aq-btn aq-tr" data-k="a">A</button>
<button class="aq-btn aq-idle" data-k="none">·</button>
<button class="aq-btn aq-tr" data-k="d">D</button>
<div></div><button class="aq-btn aq-tr" data-k="s">S</button><div></div>
</div>
</div>
<div class="aq-pad">
<div class="aq-pad-label">Rotate (IJKL)</div>
<div class="aq-grid">
<div></div><button class="aq-btn aq-ro" data-k="i">I ↑</button><div></div>
<button class="aq-btn aq-ro" data-k="j">J ←</button>
<div></div>
<button class="aq-btn aq-ro" data-k="l">L →</button>
<div></div><button class="aq-btn aq-ro" data-k="k">K ↓</button><div></div>
</div>
</div>
</div>
<div class="aq-row">
<label>Frames per tap
<input type="number" id="aq-frames" min="1" max="320" step="1" value="16" />
</label>
<span class="aq-spacer"></span>
<button class="aq-btn aq-secondary" id="aq-back">⌫ Back</button>
<button class="aq-btn aq-danger" id="aq-clear">Clear</button>
</div>
<div class="aq-queue-box">
<div class="aq-queue-header">
<span>Queue</span><span id="aq-total"></span>
</div>
<div id="aq-queue" class="aq-queue"></div>
<div id="aq-dsl" class="aq-dsl"></div>
</div>
</div>
"""
# Styles for the ActionQueue component. Every selector targets an ``aq-*``
# class that is either present in HTML_TEMPLATE or created at runtime by
# JS_ON_LOAD (``aq-chip``, ``aq-empty``).
CSS_TEMPLATE = """
.aq-wrap { display:flex; flex-direction:column; gap:10px; font-size:13px; }
.aq-pads { display:flex; gap:14px; }
.aq-pad { flex:1; background:#161b22; border:1px solid #30363d; border-radius:10px; padding:10px; }
.aq-pad-label { color:#8b949e; font-size:11px; text-transform:uppercase; letter-spacing:0.05em; margin-bottom:6px; text-align:center; }
.aq-grid { display:grid; grid-template-columns: repeat(3, 1fr); gap:6px; }
.aq-btn {
background:#21262d; color:#e6edf3; border:1px solid #30363d; border-radius:8px;
padding:10px 0; cursor:pointer; font-weight:600; font-family:inherit; font-size:13px;
transition:transform .08s ease, background .12s ease;
}
.aq-btn:hover { background:#2d333b; }
.aq-btn:active { transform:scale(0.96); }
.aq-tr { background:#1f6feb22; border-color:#1f6feb55; color:#79b8ff; }
.aq-tr:hover { background:#1f6feb44; }
.aq-ro { background:#a371f722; border-color:#a371f755; color:#c9a4ff; }
.aq-ro:hover { background:#a371f744; }
.aq-idle { background:#30363d; color:#8b949e; }
.aq-secondary { background:#30363d; }
.aq-danger { background:#da363322; border-color:#da363355; color:#ff7b72; }
.aq-row { display:flex; align-items:center; gap:10px; }
.aq-row label { color:#8b949e; font-size:12px; display:flex; align-items:center; gap:6px; }
.aq-row input[type=number] {
background:#0d1117; color:#e6edf3; border:1px solid #30363d; border-radius:6px;
padding:4px 6px; width:70px; font-family:inherit;
}
.aq-spacer { flex:1; }
.aq-queue-box { background:#0d1117; border:1px solid #30363d; border-radius:10px; padding:10px; }
.aq-queue-header { display:flex; justify-content:space-between; color:#8b949e; font-size:11px; text-transform:uppercase; margin-bottom:6px; }
.aq-queue { display:flex; flex-wrap:wrap; gap:6px; min-height:32px; }
.aq-chip {
display:inline-flex; align-items:center; gap:6px; padding:4px 8px; border-radius:14px;
background:#21262d; border:1px solid #30363d; color:#e6edf3; font-family:ui-monospace,monospace; font-size:12px;
}
.aq-chip button { background:transparent; border:none; color:#8b949e; cursor:pointer; font-size:14px; line-height:1; padding:0; }
.aq-chip button:hover { color:#ff7b72; }
.aq-empty { color:#6e7681; font-style:italic; padding:4px 0; }
.aq-dsl { margin-top:8px; padding:6px 8px; background:#010409; border-radius:6px; color:#7ee787; font-family:ui-monospace,monospace; font-size:12px; word-break:break-all; min-height:18px; }
"""
# Client-side behaviour of the ActionQueue component. The script keeps a local
# ``queue`` array, re-renders chips / totals / DSL text on every change, and
# pushes the new value out through ``props.value`` and ``trigger('change', ...)``.
# It relies on ``element``, ``props``, ``trigger`` and ``watch`` being in scope
# when the script runs.
#
# The keydown shortcuts ignore events that originate in form fields and events
# with Ctrl / Cmd / Alt held. Without that guard, Backspace in the
# "Frames per tap" box popped a queue segment and could not edit the number,
# and browser shortcuts such as Ctrl+A / Ctrl+S / Ctrl+L were swallowed.
JS_ON_LOAD = r"""
(() => {
const root = element;
const framesInput = root.querySelector('#aq-frames');
const backBtn = root.querySelector('#aq-back');
const clearBtn = root.querySelector('#aq-clear');
const queueEl = root.querySelector('#aq-queue');
const totalEl = root.querySelector('#aq-total');
const dslEl = root.querySelector('#aq-dsl');
// ----- queue state ------------------------------------------------------
let queue = Array.isArray(props.value) ? props.value.slice() : [];
const MAX_FRAMES = 160; // num_frames slider max is 161 = MAX_FRAMES + 1
function render() {
queueEl.innerHTML = '';
if (queue.length === 0) {
const e = document.createElement('span');
e.className = 'aq-empty';
e.textContent = 'No segments yet — tap a button above.';
queueEl.appendChild(e);
} else {
queue.forEach((seg, i) => {
const chip = document.createElement('span');
chip.className = 'aq-chip';
chip.textContent = `${seg.keys || 'none'}-${seg.frames}`;
const x = document.createElement('button');
x.textContent = '×';
x.title = 'Remove';
x.addEventListener('click', (ev) => {
ev.stopPropagation();
queue.splice(i, 1);
commit();
});
chip.appendChild(x);
queueEl.appendChild(chip);
});
}
const total = queue.reduce((s, x) => s + (x.frames | 0), 0);
totalEl.textContent = total
? `${queue.length} segment${queue.length > 1 ? 's' : ''} · ${total}/${MAX_FRAMES} frames`
: `0/${MAX_FRAMES} frames`;
totalEl.style.color = total >= MAX_FRAMES ? '#ff7b72' : '';
dslEl.textContent = queue.length
? queue.map(s => `${(s.keys || 'none')}-${s.frames}`).join(',')
: '(empty)';
}
function commit() {
render();
props.value = queue.slice();
trigger('change', queue.slice());
}
function appendSegment(keys) {
const requested = Math.max(1, parseInt(framesInput.value || '16', 10));
const current = queue.reduce((s, x) => s + (x.frames | 0), 0);
const remaining = MAX_FRAMES - current;
if (remaining <= 0) return; // queue full
const frames = Math.min(requested, remaining);
// Coalesce consecutive identical segments.
const last = queue[queue.length - 1];
if (last && last.keys === keys) {
last.frames += frames;
} else {
queue.push({ keys, frames });
}
commit();
}
root.querySelectorAll('.aq-btn[data-k]').forEach(btn => {
btn.addEventListener('click', () => appendSegment(btn.dataset.k));
});
backBtn.addEventListener('click', () => { queue.pop(); commit(); });
clearBtn.addEventListener('click', () => { queue = []; commit(); });
// Keyboard shortcuts (when focus is anywhere in the queue block).
root.tabIndex = 0;
root.addEventListener('keydown', (e) => {
// Leave typing in form fields (e.g. the "Frames per tap" box) and
// browser shortcuts (Ctrl / Cmd / Alt + key) alone.
const t = e.target;
if (t && (t.tagName === 'INPUT' || t.tagName === 'TEXTAREA' || t.tagName === 'SELECT' || t.isContentEditable)) return;
if (e.ctrlKey || e.metaKey || e.altKey) return;
const k = (e.key || '').toLowerCase();
if (k.length === 1 && 'wasdijkl'.includes(k)) { appendSegment(k); e.preventDefault(); }
else if (k === ' ') { appendSegment('none'); e.preventDefault(); }
else if (k === 'backspace') { queue.pop(); commit(); e.preventDefault(); }
});
// React to external value changes (Examples click, programmatic updates).
function parseDsl(s) {
return s.split(',').map(seg => {
const idx = seg.lastIndexOf('-');
if (idx < 0) return null;
const keys = seg.slice(0, idx).toLowerCase().trim();
const frames = parseInt(seg.slice(idx + 1), 10);
if (!keys || !Number.isFinite(frames) || frames < 1) return null;
return { keys, frames };
}).filter(Boolean);
}
watch('value', () => {
let incoming;
if (Array.isArray(props.value)) {
incoming = props.value;
} else if (typeof props.value === 'string') {
incoming = parseDsl(props.value);
} else {
incoming = [];
}
// Skip if the incoming value matches our local state — avoids feedback loops
// from our own commit() calls.
if (JSON.stringify(incoming) === JSON.stringify(queue)) return;
queue = incoming.map(s => ({ keys: (s.keys || 'none'), frames: Math.max(1, s.frames | 0) }));
render();
});
render();
})();
"""
class ActionQueue(gr.HTML):
    """WASD/IJKL action-queue input built on ``gr.HTML``.

    The component's value is a list of segments, ``[{keys: str, frames: int}]``.
    """

    def __init__(self, value=None, **kwargs):
        """Create the component; ``value=None`` starts with an empty queue.

        Extra keyword arguments are forwarded unchanged to ``gr.HTML``.
        """
        initial = [] if value is None else value
        super().__init__(
            value=initial,
            html_template=HTML_TEMPLATE,
            css_template=CSS_TEMPLATE,
            js_on_load=JS_ON_LOAD,
            **kwargs,
        )

    def api_info(self):
        """Return the JSON-schema description of the value: an array of segments."""
        segment_schema = {
            "type": "object",
            "properties": {
                "keys": {"type": "string"},
                "frames": {"type": "integer", "minimum": 1},
            },
            "required": ["keys", "frames"],
        }
        return {"type": "array", "items": segment_schema}
# ---------------------------------------------------------------------------
# Examples
# ---------------------------------------------------------------------------
# Directory inside the vendored Sana checkout that is searched for example
# images (``<stem>.png``) and prompts (``<stem>.txt``).
EXAMPLE_DIR = SANA_DIR / "asset" / "sana_wm"
def _example_queue(dsl: str) -> list[dict]:
    """Parse a well-formed DSL string (``"w-24,jw-8"``) into queue segments.

    Strict, unlike ``_normalize_queue``: a segment without a dash or with a
    non-integer frame count raises ``ValueError``.
    """
    # Split each segment on its LAST dash into (keys, frames).
    pairs = (segment.rsplit("-", 1) for segment in dsl.split(","))
    return [{"keys": keys.lower(), "frames": int(frames)} for keys, frames in pairs]
def _load_example_prompt(name: str) -> str:
    """Return the stripped text of ``EXAMPLE_DIR/<name>.txt``, or ``""`` if it is absent.

    The file is read as UTF-8; undecodable bytes are replaced rather than raising.
    """
    prompt_file = EXAMPLE_DIR / f"{name}.txt"
    if not prompt_file.exists():
        return ""
    return prompt_file.read_text(encoding="utf-8", errors="replace").strip()
# Rows for gr.Examples: [image path, prompt text, action DSL]. A demo whose
# image is not present under EXAMPLE_DIR is left out.
EXAMPLES = []
for stem, dsl in (
    ("demo_0", "w-24,jw-24,w-24"),
    ("demo_1", "w-32,iw-16,w-24"),
    ("demo_2", "w-16,lw-16,w-16,jw-16,w-16"),
):
    img_path = EXAMPLE_DIR / f"{stem}.png"
    if not img_path.exists():
        continue
    EXAMPLES.append([str(img_path), _load_example_prompt(stem), dsl])
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Markdown shown at the top of the page.
DESCRIPTION = """
# 🌍 SANA-WM — Camera-Controlled World Model
Image-to-video generation with 6-DoF camera control using [`Efficient-Large-Model/SANA-WM_bidirectional`](https://huggingface.co/Efficient-Large-Model/SANA-WM_bidirectional) and the [NVlabs/Sana](https://github.com/NVlabs/Sana) Stage-1 pipeline.
"""
# Page layout: inputs (image, prompt, action queue, advanced sliders) in one
# column, the generated video in the other, optional examples underneath, then
# the event wiring.
with gr.Blocks(theme=gr.themes.Soft(), title="SANA-WM Demo") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # ---- input column ---------------------------------------------------
        with gr.Column(scale=1):
            image_in = gr.Image(label="First frame", type="pil", height=300)
            prompt_in = gr.Textbox(
                label="Prompt",
                lines=4,
                placeholder="A first-person view from a strictly stationary observation point...",
            )
            gr.Markdown("### 🎮 Camera action queue")
            # Starts with a single 24-frame "w" segment.
            action_queue = ActionQueue(value=[{"keys": "w", "frames": 24}])
            # Read-only mirror of the queue's DSL string.
            dsl_view = gr.Textbox(label="Resulting DSL", interactive=False)
            run_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
            with gr.Accordion("Advanced", open=False):
                # Slider range 9..161 in steps of 8, i.e. values of the form 8k+1.
                num_frames = gr.Slider(
                    9, 161, value=41, step=8,
                    label="num_frames (auto-tracks queue length, LTX-2 VAE → 8k+1)",
                    info="Adjust to override the auto-sync; will be snapped to 8k+1.",
                )
                fps = gr.Slider(8, 24, value=16, step=1, label="fps")
                steps = gr.Slider(10, 80, value=40, step=2, label="DiT sampling steps")
                cfg_scale = gr.Slider(1.0, 10.0, value=5.0, step=0.1, label="cfg_scale")
                translation_speed = gr.Slider(
                    0.01, 0.15, value=DEFAULT_TRANSLATION_SPEED, step=0.005, label="translation_speed (per frame)"
                )
                rotation_speed_deg = gr.Slider(
                    0.2, 5.0, value=DEFAULT_ROTATION_SPEED_DEG, step=0.1, label="rotation_speed_deg (per frame)"
                )
                seed = gr.Number(value=42, precision=0, label="seed")
                show_overlay = gr.Checkbox(value=True, label="Composite Genie-style action overlay on the output")
        # ---- output column --------------------------------------------------
        with gr.Column(scale=1):
            video_out = gr.Video(label="Output (704×1280)", height=520, autoplay=True)
            gr.Markdown(
                "_The output is the Sana VAE decode of Stage-1 latents (no refiner). "
                "For peak quality use the full pipeline with `--no_refiner` disabled offline._"
            )
    if EXAMPLES:
        # action_queue accepts DSL strings (normalized in Python and parsed in
        # the component's watch hook), so we can feed the DSL string straight
        # in — gives a readable table column. cache_examples=True together with
        # cache_mode="lazy" is meant to run infer() on first click per row and
        # cache the output mp4. Only three inputs are wired, so infer() uses
        # its own defaults for every other parameter.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[image_in, prompt_in, action_queue],
            outputs=[video_out, dsl_view],
            fn=infer,
            cache_examples=True,
            cache_mode="lazy",
            label="Example image + prompt + camera action queue (lazy-cached)",
        )
    # Mirror the queue's DSL into the read-only textbox.
    action_queue.change(
        fn=lambda q: queue_to_dsl(q or []),
        inputs=[action_queue],
        outputs=[dsl_view],
    )
    # Auto-sync num_frames to the queue total (snapped to 8k+1, clipped to slider).
    def _sync_num_frames(q: list) -> int:
        """Map the queue's total frame count onto a valid num_frames slider value."""
        total = queue_total_frames(q or [])
        if total < 1:
            # Empty queue: fall back to the slider minimum.
            return 9
        # total + 1 is both the snapping target and its upper bound, so the
        # result never rounds up past it; then clamp to the slider's range.
        snapped = _snap_num_frames(total + 1, stride=8, upper_bound=total + 1)
        return max(9, min(161, snapped))
    action_queue.change(
        fn=_sync_num_frames,
        inputs=[action_queue],
        outputs=[num_frames],
    )
    # Inputs are listed in the positional order of infer()'s parameters;
    # ``progress`` is not wired here and keeps its default.
    run_btn.click(
        fn=infer,
        inputs=[
            image_in,
            prompt_in,
            action_queue,
            num_frames,
            fps,
            steps,
            cfg_scale,
            translation_speed,
            rotation_speed_deg,
            seed,
            show_overlay,
        ],
        outputs=[video_out, dsl_view],
    )
# Script entry point: launch the Gradio app.
if __name__ == "__main__":
    # Extra <head> markup that loads three.js r128 from cdnjs.
    # NOTE(review): nothing in JS_ON_LOAD references THREE — confirm whether
    # this script tag is still needed.
    head = '<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>'
    demo.launch(head=head)