Spaces:
Running on Zero
Running on Zero
| """ZeroGPU Gradio demo for SANA-WM (bidirectional). | |
| Loads ``Efficient-Large-Model/SANA-WM_bidirectional`` (Stage-1 DiT + LTX-2 VAE | |
| + gemma-2-2b-it text encoder, no refiner) and lets the user roll out a | |
| camera-controlled image-to-video clip from a WASD/IJKL action queue. | |
| """ | |
| from __future__ import annotations | |
| # ``spaces`` must be imported BEFORE anything that initialises CUDA (including | |
| # torch, mmcv, etc.). mmcv pulls torch in transitively, so even our runtime | |
| # install path below would touch CUDA emulation first if ``spaces`` weren't | |
| # already imported. | |
| import spaces # noqa: E402,F401 | |
| import os | |
| import sys | |
| import time | |
| import tempfile | |
| from pathlib import Path | |
# The feat/sana-wm branch of Sana is vendored into ./Sana (next to this file)
# when the Space is deployed; put it at the front of the import path.
ROOT = Path(__file__).resolve().parent
SANA_DIR = ROOT / "Sana"
if not any(entry == str(SANA_DIR) for entry in sys.path):
    sys.path.insert(0, str(SANA_DIR))
# Must be set before any Sana / xformers import (see inference_sana_wm.py).
if "DISABLE_XFORMERS" not in os.environ:
    os.environ["DISABLE_XFORMERS"] = "1"
# mmcv 1.7.2's setup.py imports ``pkg_resources`` (removed in setuptools 80+),
# so it cannot be installed via Spaces' isolated pip build env. We follow the
# Sana ``environment_setup.sh`` recipe and install it here on first cold start:
# pin setuptools<80, then build mmcv with ``--no-build-isolation`` so the
# build sees the env's setuptools. On later starts the import succeeds and
# this is a no-op.
def _ensure_mmcv() -> None:
    """Import ``mmcv``, pip-installing ``mmcv==1.7.2`` into this env if missing.

    Side effects: when mmcv is absent, installs ``setuptools<80`` and ``wheel``
    into the running interpreter's environment before building mmcv.

    Raises:
        subprocess.CalledProcessError: if either pip invocation fails.
        ImportError: if ``mmcv`` still cannot be imported after installation.
    """
    try:
        import mmcv  # noqa: F401
    except ImportError:
        pass
    else:
        return

    import importlib
    import subprocess

    print("[startup] mmcv not found — installing with --no-build-isolation...", flush=True)
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--quiet", "setuptools<80", "wheel"]
    )
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install", "--quiet",
            "--no-build-isolation", "mmcv==1.7.2",
        ]
    )
    # The import system caches directory listings of sys.path entries; drop
    # them so the package pip just wrote is visible to this process.
    importlib.invalidate_caches()
    import mmcv  # noqa: F401,F811
    print("[startup] mmcv installed.", flush=True)


_ensure_mmcv()
# flash-attn matches Sana's ``environment_setup.sh`` pin (``flash-attn>=2.7.0``).
# ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` skips the nvcc kernel build, so the
# wheel installs in seconds — but ``flash_attn/__init__.py`` then unconditionally
# does ``import flash_attn_2_cuda``, which doesn't exist. We pre-stub that
# CUDA-extension module in ``sys.modules`` so the import succeeds; SANA-WM uses
# Triton GDN paths and never reaches the actual CUDA kernels.
def _ensure_flash_attn() -> None:
    """Make ``import flash_attn`` succeed, pip-installing it without CUDA kernels if missing.

    Side effects: registers an empty ``flash_attn_2_cuda`` module in
    ``sys.modules`` (unless that name is already present). The stub shadows any
    real compiled extension, so code that reaches flash-attn's CUDA kernels
    would fail.

    Raises:
        subprocess.CalledProcessError: if the pip install fails.
        ImportError: if ``flash_attn`` still cannot be imported afterwards.
    """
    import types

    if "flash_attn_2_cuda" not in sys.modules:
        sys.modules["flash_attn_2_cuda"] = types.ModuleType("flash_attn_2_cuda")

    try:
        import flash_attn  # noqa: F401
    except ImportError:
        pass
    else:
        return

    import importlib
    import subprocess

    print("[startup] flash-attn not found — installing (CUDA build skipped)...", flush=True)
    env = dict(os.environ, FLASH_ATTENTION_SKIP_CUDA_BUILD="TRUE")
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install", "--quiet",
            "--no-build-isolation", "flash-attn>=2.7.0",
        ],
        env=env,
    )
    # Drop anything left behind by the failed import attempt (the stub
    # ``flash_attn_2_cuda`` is kept: it does not match these names) and
    # refresh the import system's directory caches so the fresh install is found.
    stale = [m for m in sys.modules if m == "flash_attn" or m.startswith("flash_attn.")]
    for name in stale:
        del sys.modules[name]
    importlib.invalidate_caches()
    import flash_attn  # noqa: F401,F811
    print("[startup] flash-attn installed.", flush=True)


_ensure_flash_attn()
| import gradio as gr | |
| import imageio.v3 as iio | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| # Sana-WM imports — registered by importing nets, then the pipeline module. | |
| import diffusion.model.nets # noqa: F401 | |
| from inference_video_scripts.inference_sana_wm import ( | |
| DEFAULT_ROTATION_SPEED_DEG, | |
| DEFAULT_TRANSLATION_SPEED, | |
| GenerationParams, | |
| InferenceConfig, | |
| SanaWMPipeline, | |
| TARGET_HEIGHT, | |
| TARGET_WIDTH, | |
| action_string_to_c2w, | |
| estimate_intrinsics_with_pi3x, | |
| resize_and_center_crop, | |
| transform_intrinsics_for_crop, | |
| ) | |
| from diffusion.utils.action_overlay import apply_overlay | |
| from diffusion.utils.logger import get_root_logger | |
| import pyrallis | |
| from sana.tools import resolve_hf_path | |
# Hugging Face Hub locations of the SANA-WM config and Stage-1 DiT checkpoint;
# both are turned into local paths by ``resolve_hf_path`` below.
HF_REPO = "Efficient-Large-Model/SANA-WM_bidirectional"
CONFIG_URI = f"hf://{HF_REPO}/config.yaml"
MODEL_URI = f"hf://{HF_REPO}/dit/sana_wm_1600m_720p.safetensors"
LOGGER = get_root_logger()
# ---------------------------------------------------------------------------
# Pipeline — built at module load time. ZeroGPU's CUDA emulation lets us place
# tensors on "cuda" without a real GPU; @spaces.GPU swaps in a real device for
# the duration of the decorated call. Lazy-loading inside the decorator is
# strongly discouraged on ZeroGPU.
# ---------------------------------------------------------------------------
print("[startup] Building SanaWMPipeline (downloads ~17 GB on first launch)...", flush=True)
_t0 = time.time()
# Load an InferenceConfig from the repo's config.yaml.
# NOTE(review): args=[] presumably stops pyrallis from also parsing this
# process's command-line arguments — confirm against pyrallis.parse.
PIPELINE_CONFIG: InferenceConfig = pyrallis.parse(
    config_class=InferenceConfig, config_path=resolve_hf_path(CONFIG_URI), args=[]
)
# Stage-1 pipeline only: no refiner is attached.
PIPELINE = SanaWMPipeline(
    config=PIPELINE_CONFIG,
    model_path=resolve_hf_path(MODEL_URI),
    device="cuda",
    refiner=None,
    logger=LOGGER,
)
print(f"[startup] Pipeline ready in {time.time() - _t0:.1f}s.", flush=True)
# ---------------------------------------------------------------------------
# Action-queue helpers (queue list <-> DSL string)
# ---------------------------------------------------------------------------
# Keys a DSL segment may contain: WASD translate the camera, IJKL rotate it.
ALLOWED_KEYS = {"w", "a", "s", "d", "i", "j", "k", "l"}
def _normalize_queue(queue) -> list[dict]:
    """Coerce an action-queue value into a list of segment dicts.

    * falsy (``None``, ``""``, ``[]``) -> ``[]``
    * ``list`` (from the UI) -> its ``dict`` elements; other items are dropped
    * ``str`` DSL such as ``"w-16,jw-8"`` (from examples / serialized state) ->
      ``[{"keys": "w", "frames": 16}, {"keys": "jw", "frames": 8}]``. Segments
      without a ``-`` or with a non-integer frame count are skipped, keys are
      lower-cased and frame counts are clamped to at least 1.
    * anything else -> ``[]``
    """
    if not queue:
        return []
    if isinstance(queue, list):
        return [seg for seg in queue if isinstance(seg, dict)]
    if not isinstance(queue, str):
        return []

    parsed: list[dict] = []
    for raw in (piece.strip() for piece in queue.split(",")):
        if "-" not in raw:  # also covers empty pieces
            continue
        keys, _, frames = raw.rpartition("-")
        try:
            count = int(frames)
        except ValueError:
            continue
        parsed.append({"keys": keys.lower(), "frames": max(1, count)})
    return parsed
def queue_to_dsl(queue) -> str:
    """Serialize an action queue to its DSL string, e.g. ``"jw-16,w-8"``.

    Accepts anything ``_normalize_queue`` accepts (segment list or DSL string).
    Within each segment only allowed keys are kept, de-duplicated and sorted;
    a segment with no allowed key becomes ``"none"``. Frame counts are clamped
    to at least 1. An empty queue yields ``""``.
    """

    def _render(seg: dict) -> str:
        raw = (seg.get("keys") or "").lower().strip()
        pressed = sorted({ch for ch in raw if ch in ALLOWED_KEYS})
        label = "".join(pressed) or "none"
        return f"{label}-{max(1, int(seg.get('frames', 1)))}"

    return ",".join(_render(seg) for seg in _normalize_queue(queue))
def queue_total_frames(queue) -> int:
    """Return the summed frame count of *queue*, counting each segment as >= 1."""
    total = 0
    for seg in _normalize_queue(queue):
        total += max(1, int(seg.get("frames", 1)))
    return total
def _snap_num_frames(n: int, stride: int = 8, upper_bound: int | None = None) -> int:
    """Snap *n* to the nearest ``stride * k + 1`` count (LTX-2 VAE wants 8k+1).

    Ties round up. When *upper_bound* is given, the result never exceeds it:
    an overshooting value is replaced by the largest ``stride * k + 1`` that is
    ``<= upper_bound``. The result is always at least 1.

    Args:
        n: requested frame count.
        stride: alignment step (8 for the LTX-2 VAE).
        upper_bound: optional inclusive cap, e.g. the number of available poses.

    Returns:
        The snapped frame count.
    """
    if n < 1:
        return 1
    rem = (n - 1) % stride
    if rem == 0:
        snapped = n
    else:
        floor_cand = n - rem
        ceil_cand = floor_cand + stride
        snapped = floor_cand if (n - floor_cand) < (ceil_cand - n) else ceil_cand
    if upper_bound is not None and snapped > upper_bound:
        # Clamp every overshoot, not only the round-up case: an already
        # aligned n that lies above the bound must be capped as well.
        if upper_bound < 1:
            return 1
        snapped = upper_bound - (upper_bound - 1) % stride
    return max(snapped, 1)
| # --------------------------------------------------------------------------- | |
| # tqdm <-> gr.Progress bridge | |
| # | |
| # Sana's samplers do ``from tqdm import tqdm`` at module load, so the symbol | |
| # binds before our patch runs. We swap that *bound name* on the sampler module | |
| # with a subclass that forwards progress on every update, then restore on exit. | |
| # Avoids ``track_tqdm=True``, which breaks inside ZeroGPU subprocesses. | |
| # --------------------------------------------------------------------------- | |
| import contextlib | |
| import importlib | |
# Sana sampler modules whose module-level ``tqdm`` name gets temporarily
# rebound to forward progress to Gradio.
_SANA_TQDM_TARGETS = [
    f"diffusion.{suffix}"
    for suffix in (
        "scheduler.flow_euler_sampler",
        "scheduler.dpm_solver",
        "model.dpm_solver",
    )
]
@contextlib.contextmanager
def _bridge_sana_tqdm(
    progress: "gr.Progress",
    *,
    span=(0.4, 0.9),
    desc_prefix="DiT",
    targets=None,
):
    """Temporarily forward Sana's tqdm ticks to a Gradio progress callback.

    Rebinds the ``tqdm`` name on each target module to a subclass whose
    ``update`` also reports ``lo + (hi - lo) * n / total`` to *progress*. The
    original bindings are restored on exit, even if the body raises.

    Args:
        progress: callable taking ``(fraction, desc=...)``, typically a
            ``gr.Progress`` instance.
        span: ``(lo, hi)`` sub-range of the overall bar that the sampler's
            ticks are mapped onto.
        desc_prefix: label prefix for the progress description.
        targets: module names to patch; ``None`` means ``_SANA_TQDM_TARGETS``.
            Modules that fail to import, or whose ``tqdm`` attribute is
            missing or not a class, are left untouched.
    """
    import tqdm as _tqdm_mod

    if targets is None:
        targets = _SANA_TQDM_TARGETS
    lo, hi = span

    class _BridgedTqdm(_tqdm_mod.tqdm):
        def update(self, n=1):
            ret = super().update(n)
            # Best effort: a failing progress callback must never abort sampling.
            try:
                if self.total:
                    frac = max(0.0, min(1.0, self.n / self.total))
                    progress(lo + (hi - lo) * frac, desc=f"{desc_prefix} step {self.n}/{self.total}")
            except Exception:
                pass
            return ret

    saved = []  # (module, original tqdm binding), in patch order
    try:
        for mod_name in targets:
            try:
                mod = importlib.import_module(mod_name)
            except Exception:
                continue
            original = getattr(mod, "tqdm", None)
            # Only replace a class binding (``from tqdm import tqdm``); rebinding
            # a module-valued ``tqdm`` would break ``tqdm.tqdm(...)`` calls.
            if isinstance(original, type):
                saved.append((mod, original))
                mod.tqdm = _BridgedTqdm
        yield
    finally:
        # Reverse order so a module listed twice ends up with its true original.
        for mod, original in reversed(saved):
            mod.tqdm = original
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
# On ZeroGPU a real GPU is attached only while a @spaces.GPU-decorated call is
# running, and everything below places work on "cuda".
# NOTE(review): the default ZeroGPU time budget may be short for 161-frame,
# 40-step rollouts — pass ``duration=...`` if long clips time out.
@spaces.GPU
def infer(
    image: Image.Image,
    prompt: str,
    queue_state,
    num_frames: int = 161,
    fps: int = 16,
    steps: int = 40,
    cfg_scale: float = 5.0,
    translation_speed: float = DEFAULT_TRANSLATION_SPEED,
    rotation_speed_deg: float = DEFAULT_ROTATION_SPEED_DEG,
    seed: int = 42,
    show_overlay: bool = True,
    progress: gr.Progress = gr.Progress(),
):
    """Generate a camera-controlled clip from a first frame, a prompt and an action queue.

    Args:
        image: first frame (a PIL image from the ``gr.Image(type="pil")`` input).
        prompt: text prompt; surrounding whitespace is stripped.
        queue_state: action queue as a list of ``{"keys", "frames"}`` dicts or
            a DSL string (normalized through ``queue_to_dsl``).
        num_frames: requested length; clipped to the rolled-out trajectory and
            snapped to 8k+1 for the LTX-2 VAE.
        fps, steps, cfg_scale, seed: sampling / encoding settings. A ``None``
            seed (cleared number field) falls back to 42.
        translation_speed, rotation_speed_deg: per-frame camera speeds used to
            turn the DSL into camera-to-world poses.
        show_overlay: composite the action overlay onto the output frames.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(video_path, dsl)``: path of the written MP4 and the normalized DSL.

    Raises:
        gr.Error: on a missing image or prompt, an empty queue, or a trajectory
            too short to yield 9 frames.
    """
    if image is None:
        raise gr.Error("Please upload a first-frame image.")
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Please enter a prompt.")
    dsl = queue_to_dsl(queue_state or [])
    if not dsl:
        raise gr.Error("Action queue is empty — add at least one segment.")
    steps = int(steps)
    # gr.Number yields None when the field is cleared; use the default seed.
    seed = 42 if seed is None else int(seed)

    progress(0.05, desc="Rolling out trajectory")
    c2w_full = action_string_to_c2w(
        dsl,
        translation_speed=translation_speed,
        rotation_speed_deg=rotation_speed_deg,
    )
    # Never request more frames than the trajectory has poses for.
    upper = c2w_full.shape[0]
    snapped = _snap_num_frames(min(int(num_frames), upper), stride=8, upper_bound=upper)
    if snapped < 9:
        raise gr.Error(
            f"Need at least 9 frames after LTX-2 snapping (8k+1); your queue rolled out {upper} frames."
        )
    c2w = c2w_full[:snapped]

    progress(0.1, desc="Preparing image")
    cropped, src_size, resized_size, crop_offset = resize_and_center_crop(image)

    progress(0.15, desc="Estimating intrinsics (Pi3X)")
    device = torch.device("cuda")
    intr_one = estimate_intrinsics_with_pi3x(image, device, LOGGER)
    # A single estimate from the uploaded frame, repeated for every output
    # frame (assumes a 4-element intrinsics vector — TODO confirm), then mapped
    # into the cropped frame's coordinates.
    intr_src = np.broadcast_to(intr_one, (snapped, 4)).copy()
    intrinsics_vec4 = transform_intrinsics_for_crop(intr_src, src_size, resized_size, crop_offset)

    params = GenerationParams(
        num_frames=snapped,
        fps=int(fps),
        step=steps,
        cfg_scale=float(cfg_scale),
        seed=seed,
        sampling_algo="flow_euler_ltx",
    )

    progress(0.4, desc=f"Sampling {snapped} frames in {steps} steps")
    t0 = time.time()
    with _bridge_sana_tqdm(progress, span=(0.4, 0.9), desc_prefix="DiT"):
        out = PIPELINE.generate(cropped, prompt, c2w, intrinsics_vec4, params)
    LOGGER.info(f"Sampling done in {time.time() - t0:.1f}s")

    video_hwc = out["video"]
    if show_overlay:
        progress(0.92, desc="Compositing action overlay")
        video_hwc = apply_overlay(video_hwc, out["c2w"])

    # A fresh directory per call, so concurrent requests never collide on
    # the fixed file name.
    out_path = Path(tempfile.mkdtemp()) / "sana_wm.mp4"
    iio.imwrite(out_path, video_hwc, fps=params.fps)
    return str(out_path), dsl
# ---------------------------------------------------------------------------
# ActionQueue: custom gr.HTML component
#   - 8 directional buttons (WASD = translate, IJKL = rotate) + "idle"
#   - Frames-per-tap input (capped at 160, the JS queue limit)
#   - Removable queue chips and a live DSL readout
# NOTE(review): no Three.js trajectory preview is implemented in this markup
# yet, although three.js is loaded via ``head`` at launch.
# ---------------------------------------------------------------------------
HTML_TEMPLATE = """
<div class="aq-wrap">
  <div class="aq-pads">
    <div class="aq-pad">
      <div class="aq-pad-label">Translate (WASD)</div>
      <div class="aq-grid">
        <div></div><button class="aq-btn aq-tr" data-k="w">W</button><div></div>
        <button class="aq-btn aq-tr" data-k="a">A</button>
        <button class="aq-btn aq-idle" data-k="none">·</button>
        <button class="aq-btn aq-tr" data-k="d">D</button>
        <div></div><button class="aq-btn aq-tr" data-k="s">S</button><div></div>
      </div>
    </div>
    <div class="aq-pad">
      <div class="aq-pad-label">Rotate (IJKL)</div>
      <div class="aq-grid">
        <div></div><button class="aq-btn aq-ro" data-k="i">I ↑</button><div></div>
        <button class="aq-btn aq-ro" data-k="j">J ←</button>
        <div></div>
        <button class="aq-btn aq-ro" data-k="l">L →</button>
        <div></div><button class="aq-btn aq-ro" data-k="k">K ↓</button><div></div>
      </div>
    </div>
  </div>
  <div class="aq-row">
    <label>Frames per tap
      <input type="number" id="aq-frames" min="1" max="160" step="1" value="16" />
    </label>
    <span class="aq-spacer"></span>
    <button class="aq-btn aq-secondary" id="aq-back">⌫ Back</button>
    <button class="aq-btn aq-danger" id="aq-clear">Clear</button>
  </div>
  <div class="aq-queue-box">
    <div class="aq-queue-header">
      <span>Queue</span><span id="aq-total"></span>
    </div>
    <div id="aq-queue" class="aq-queue"></div>
    <div id="aq-dsl" class="aq-dsl"></div>
  </div>
</div>
"""
# Styles for the ActionQueue markup in HTML_TEMPLATE: the translate (.aq-tr)
# and rotate (.aq-ro) button pads, the frames / back / clear row, the queue
# chips, and the DSL readout (.aq-dsl).
CSS_TEMPLATE = """
.aq-wrap { display:flex; flex-direction:column; gap:10px; font-size:13px; }
.aq-pads { display:flex; gap:14px; }
.aq-pad { flex:1; background:#161b22; border:1px solid #30363d; border-radius:10px; padding:10px; }
.aq-pad-label { color:#8b949e; font-size:11px; text-transform:uppercase; letter-spacing:0.05em; margin-bottom:6px; text-align:center; }
.aq-grid { display:grid; grid-template-columns: repeat(3, 1fr); gap:6px; }
.aq-btn {
background:#21262d; color:#e6edf3; border:1px solid #30363d; border-radius:8px;
padding:10px 0; cursor:pointer; font-weight:600; font-family:inherit; font-size:13px;
transition:transform .08s ease, background .12s ease;
}
.aq-btn:hover { background:#2d333b; }
.aq-btn:active { transform:scale(0.96); }
.aq-tr { background:#1f6feb22; border-color:#1f6feb55; color:#79b8ff; }
.aq-tr:hover { background:#1f6feb44; }
.aq-ro { background:#a371f722; border-color:#a371f755; color:#c9a4ff; }
.aq-ro:hover { background:#a371f744; }
.aq-idle { background:#30363d; color:#8b949e; }
.aq-secondary { background:#30363d; }
.aq-danger { background:#da363322; border-color:#da363355; color:#ff7b72; }
.aq-row { display:flex; align-items:center; gap:10px; }
.aq-row label { color:#8b949e; font-size:12px; display:flex; align-items:center; gap:6px; }
.aq-row input[type=number] {
background:#0d1117; color:#e6edf3; border:1px solid #30363d; border-radius:6px;
padding:4px 6px; width:70px; font-family:inherit;
}
.aq-spacer { flex:1; }
.aq-queue-box { background:#0d1117; border:1px solid #30363d; border-radius:10px; padding:10px; }
.aq-queue-header { display:flex; justify-content:space-between; color:#8b949e; font-size:11px; text-transform:uppercase; margin-bottom:6px; }
.aq-queue { display:flex; flex-wrap:wrap; gap:6px; min-height:32px; }
.aq-chip {
display:inline-flex; align-items:center; gap:6px; padding:4px 8px; border-radius:14px;
background:#21262d; border:1px solid #30363d; color:#e6edf3; font-family:ui-monospace,monospace; font-size:12px;
}
.aq-chip button { background:transparent; border:none; color:#8b949e; cursor:pointer; font-size:14px; line-height:1; padding:0; }
.aq-chip button:hover { color:#ff7b72; }
.aq-empty { color:#6e7681; font-style:italic; padding:4px 0; }
.aq-dsl { margin-top:8px; padding:6px 8px; background:#010409; border-radius:6px; color:#7ee787; font-family:ui-monospace,monospace; font-size:12px; word-break:break-all; min-height:18px; }
"""
# Client-side behaviour of the ActionQueue component (run once per mount with
# ``element``, ``props``, ``trigger`` and ``watch`` provided by gr.HTML):
# button / keyboard input, a 160-frame queue cap, chip rendering and syncing
# ``props.value`` (segment list or DSL string) in both directions.
JS_ON_LOAD = r"""
(() => {
  const root = element;
  const framesInput = root.querySelector('#aq-frames');
  const backBtn = root.querySelector('#aq-back');
  const clearBtn = root.querySelector('#aq-clear');
  const queueEl = root.querySelector('#aq-queue');
  const totalEl = root.querySelector('#aq-total');
  const dslEl = root.querySelector('#aq-dsl');

  const MAX_FRAMES = 160; // num_frames slider max is 161 = MAX_FRAMES + 1
  const DEFAULT_FRAMES = 16;

  // ----- value parsing ----------------------------------------------------
  // "w-16,jw-8" -> [{keys: 'w', frames: 16}, {keys: 'jw', frames: 8}]
  function parseDsl(s) {
    return s.split(',').map(seg => {
      const idx = seg.lastIndexOf('-');
      if (idx < 0) return null;
      const keys = seg.slice(0, idx).toLowerCase().trim();
      const frames = parseInt(seg.slice(idx + 1), 10);
      if (!keys || !Number.isFinite(frames) || frames < 1) return null;
      return { keys, frames };
    }).filter(Boolean);
  }

  // props.value is a segment list (UI / Python) or a DSL string (Examples).
  function toSegments(v) {
    if (Array.isArray(v)) return v;
    if (typeof v === 'string') return parseDsl(v);
    return [];
  }

  // Fresh objects, so coalescing never mutates objects owned by props.value.
  function normalize(segs) {
    return segs
      .filter(s => s && typeof s === 'object')
      .map(s => ({ keys: (s.keys || 'none'), frames: Math.max(1, s.frames | 0) }));
  }

  // ----- queue state ------------------------------------------------------
  let queue = normalize(toSegments(props.value));

  function totalFrames() {
    return queue.reduce((s, x) => s + (x.frames | 0), 0);
  }

  function render() {
    queueEl.innerHTML = '';
    if (queue.length === 0) {
      const e = document.createElement('span');
      e.className = 'aq-empty';
      e.textContent = 'No segments yet — tap a button above.';
      queueEl.appendChild(e);
    } else {
      queue.forEach((seg, i) => {
        const chip = document.createElement('span');
        chip.className = 'aq-chip';
        chip.textContent = `${seg.keys || 'none'}-${seg.frames}`;
        const x = document.createElement('button');
        x.textContent = '×';
        x.title = 'Remove';
        x.addEventListener('click', (ev) => {
          ev.stopPropagation();
          queue.splice(i, 1);
          commit();
        });
        chip.appendChild(x);
        queueEl.appendChild(chip);
      });
    }
    const total = totalFrames();
    totalEl.textContent = total
      ? `${queue.length} segment${queue.length > 1 ? 's' : ''} · ${total}/${MAX_FRAMES} frames`
      : `0/${MAX_FRAMES} frames`;
    totalEl.style.color = total >= MAX_FRAMES ? '#ff7b72' : '';
    dslEl.textContent = queue.length
      ? queue.map(s => `${(s.keys || 'none')}-${s.frames}`).join(',')
      : '(empty)';
  }

  function commit() {
    render();
    props.value = queue.slice();
    trigger('change', queue.slice());
  }

  function appendSegment(keys) {
    const parsed = parseInt(framesInput.value || String(DEFAULT_FRAMES), 10);
    const requested = Number.isFinite(parsed) ? Math.max(1, parsed) : DEFAULT_FRAMES;
    const remaining = MAX_FRAMES - totalFrames();
    if (remaining <= 0) return; // queue full
    const frames = Math.min(requested, remaining);
    // Coalesce consecutive identical segments.
    const last = queue[queue.length - 1];
    if (last && last.keys === keys) {
      last.frames += frames;
    } else {
      queue.push({ keys, frames });
    }
    commit();
  }

  root.querySelectorAll('.aq-btn[data-k]').forEach(btn => {
    btn.addEventListener('click', () => appendSegment(btn.dataset.k));
  });
  backBtn.addEventListener('click', () => { queue.pop(); commit(); });
  clearBtn.addEventListener('click', () => { queue = []; commit(); });

  // Keyboard shortcuts (when focus is anywhere in the queue block).
  root.tabIndex = 0;
  root.addEventListener('keydown', (e) => {
    // Let the frames field receive its own keystrokes (digits, Backspace)
    // and leave browser / OS shortcuts (Ctrl, Cmd or Alt + key) alone.
    const t = e.target;
    if (t && (t.tagName === 'INPUT' || t.tagName === 'TEXTAREA' || t.isContentEditable)) return;
    if (e.ctrlKey || e.metaKey || e.altKey) return;
    const k = (e.key || '').toLowerCase();
    if (k.length === 1 && 'wasdijkl'.includes(k)) { appendSegment(k); e.preventDefault(); }
    else if (k === ' ') { appendSegment('none'); e.preventDefault(); }
    else if (k === 'backspace') { queue.pop(); commit(); e.preventDefault(); }
  });

  // React to external value changes (Examples click, programmatic updates).
  watch('value', () => {
    const incoming = toSegments(props.value);
    // Skip if the incoming value matches our local state — avoids feedback loops
    // from our own commit() calls.
    if (JSON.stringify(incoming) === JSON.stringify(queue)) return;
    queue = normalize(incoming);
    render();
  });

  render();
})();
"""
class ActionQueue(gr.HTML):
    """``gr.HTML`` input for building a WASD/IJKL camera action queue.

    The component value is a list of ``{"keys": str, "frames": int}``
    segments and defaults to an empty list. The markup, styles and client-side
    behaviour come from ``HTML_TEMPLATE``, ``CSS_TEMPLATE`` and ``JS_ON_LOAD``.
    """

    def __init__(self, value=None, **kwargs):
        super().__init__(
            value=[] if value is None else value,
            html_template=HTML_TEMPLATE,
            css_template=CSS_TEMPLATE,
            js_on_load=JS_ON_LOAD,
            **kwargs,
        )

    def api_info(self):
        """Return the JSON schema advertised for this component's value."""
        segment_schema = {
            "type": "object",
            "properties": {
                "keys": {"type": "string"},
                "frames": {"type": "integer", "minimum": 1},
            },
            "required": ["keys", "frames"],
        }
        return {"type": "array", "items": segment_schema}
# ---------------------------------------------------------------------------
# Examples
# ---------------------------------------------------------------------------
# Demo assets (<stem>.png first frames, <stem>.txt prompts) shipped in the
# vendored Sana checkout.
EXAMPLE_DIR = SANA_DIR.joinpath("asset", "sana_wm")
def _example_queue(dsl: str) -> list[dict]:
    """Strictly parse a DSL string like ``"w-24,jw-24"`` into queue segments.

    Keys are lower-cased. Raises ``ValueError`` on a segment without ``-`` or
    with a non-integer frame count.
    """

    def _parse(seg: str) -> dict:
        keys, frames = seg.rsplit("-", 1)
        return {"keys": keys.lower(), "frames": int(frames)}

    return [_parse(seg) for seg in dsl.split(",")]
def _load_example_prompt(name: str) -> str:
    """Return the stripped text of ``EXAMPLE_DIR/<name>.txt``, or ``""`` if it is absent."""
    prompt_file = EXAMPLE_DIR / f"{name}.txt"
    if not prompt_file.exists():
        return ""
    return prompt_file.read_text(encoding="utf-8", errors="replace").strip()
# gr.Examples rows of [image path, prompt, action-queue DSL]; stems whose
# PNG is missing from EXAMPLE_DIR are skipped.
EXAMPLES = []
for stem, dsl in (
    ("demo_0", "w-24,jw-24,w-24"),
    ("demo_1", "w-32,iw-16,w-24"),
    ("demo_2", "w-16,lw-16,w-16,jw-16,w-16"),
):
    img_path = EXAMPLE_DIR / f"{stem}.png"
    if not img_path.exists():
        continue
    EXAMPLES.append([str(img_path), _load_example_prompt(stem), dsl])
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Markdown header rendered at the top of the demo page (via gr.Markdown).
DESCRIPTION = """
# 🌍 SANA-WM — Camera-Controlled World Model
Image-to-video generation with 6-DoF camera control using [`Efficient-Large-Model/SANA-WM_bidirectional`](https://huggingface.co/Efficient-Large-Model/SANA-WM_bidirectional) and the [NVlabs/Sana](https://github.com/NVlabs/Sana) Stage-1 pipeline.
"""
with gr.Blocks(theme=gr.themes.Soft(), title="SANA-WM Demo") as demo:
    # Auto-sync num_frames to the queue total (snapped to 8k+1, clipped to the
    # slider's [9, 161] range). The +1 mirrors the JS side, where a 160-frame
    # queue maps to the 161-frame slider maximum. Defined first so it can also
    # seed the slider's initial value.
    def _sync_num_frames(q: list) -> int:
        total = queue_total_frames(q or [])
        if total < 1:
            return 9
        snapped = _snap_num_frames(total + 1, stride=8, upper_bound=total + 1)
        return max(9, min(161, snapped))

    initial_queue = [{"keys": "w", "frames": 24}]

    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            image_in = gr.Image(label="First frame", type="pil", height=300)
            prompt_in = gr.Textbox(
                label="Prompt",
                lines=4,
                placeholder="A first-person view from a strictly stationary observation point...",
            )
            gr.Markdown("### 🎮 Camera action queue")
            action_queue = ActionQueue(value=initial_queue)
            dsl_view = gr.Textbox(label="Resulting DSL", interactive=False)
            run_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
            with gr.Accordion("Advanced", open=False):
                num_frames = gr.Slider(
                    # Start in sync with the initial queue instead of a fixed 41.
                    9, 161, value=_sync_num_frames(initial_queue), step=8,
                    label="num_frames (auto-tracks queue length, LTX-2 VAE → 8k+1)",
                    info="Adjust to override the auto-sync; will be snapped to 8k+1.",
                )
                fps = gr.Slider(8, 24, value=16, step=1, label="fps")
                steps = gr.Slider(10, 80, value=40, step=2, label="DiT sampling steps")
                cfg_scale = gr.Slider(1.0, 10.0, value=5.0, step=0.1, label="cfg_scale")
                translation_speed = gr.Slider(
                    0.01, 0.15, value=DEFAULT_TRANSLATION_SPEED, step=0.005, label="translation_speed (per frame)"
                )
                rotation_speed_deg = gr.Slider(
                    0.2, 5.0, value=DEFAULT_ROTATION_SPEED_DEG, step=0.1, label="rotation_speed_deg (per frame)"
                )
                seed = gr.Number(value=42, precision=0, label="seed")
                show_overlay = gr.Checkbox(value=True, label="Composite Genie-style action overlay on the output")
        with gr.Column(scale=1):
            # Label derived from the pipeline's target resolution constants.
            video_out = gr.Video(label=f"Output ({TARGET_HEIGHT}×{TARGET_WIDTH})", height=520, autoplay=True)
            gr.Markdown(
                "_The output is the LTX-2 VAE decode of Stage-1 latents (no refiner). "
                "For peak quality, run the full pipeline with the refiner enabled "
                "(i.e. without `--no_refiner`) offline._"
            )

    if EXAMPLES:
        # action_queue accepts DSL strings (normalized in Python and parsed in
        # the component's watch hook), so we can feed the DSL string straight
        # in — gives a readable table column. cache_examples=lazy runs infer()
        # on first click per row and caches the output mp4.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[image_in, prompt_in, action_queue],
            outputs=[video_out, dsl_view],
            fn=infer,
            cache_examples=True,
            cache_mode="lazy",
            label="Example image + prompt + camera action queue (lazy-cached)",
        )

    # Mirror the queue's DSL into the read-only textbox.
    action_queue.change(
        fn=lambda q: queue_to_dsl(q or []),
        inputs=[action_queue],
        outputs=[dsl_view],
    )
    action_queue.change(
        fn=_sync_num_frames,
        inputs=[action_queue],
        outputs=[num_frames],
    )
    run_btn.click(
        fn=infer,
        inputs=[
            image_in,
            prompt_in,
            action_queue,
            num_frames,
            fps,
            steps,
            cfg_scale,
            translation_speed,
            rotation_speed_deg,
            seed,
            show_overlay,
        ],
        outputs=[video_out, dsl_view],
    )
if __name__ == "__main__":
    # Inject three.js (r128, from cdnjs) into the page <head>.
    # NOTE(review): JS_ON_LOAD does not appear to reference ``THREE``; confirm
    # this script tag is still needed before keeping the extra CDN fetch.
    launch_kwargs = {
        "head": '<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>',
    }
    demo.launch(**launch_kwargs)