"""ZeroGPU Gradio demo for SANA-WM (bidirectional).

Loads ``Efficient-Large-Model/SANA-WM_bidirectional`` (Stage-1 DiT + LTX-2 VAE +
gemma-2-2b-it text encoder, no refiner) and lets the user roll out a
camera-controlled image-to-video clip from a WASD/IJKL action queue.
"""

from __future__ import annotations

# ``spaces`` must be imported BEFORE anything that initialises CUDA (including
# torch, mmcv, etc.). mmcv pulls torch in transitively, so even our runtime
# install path below would touch CUDA emulation first if ``spaces`` weren't
# already imported.
import spaces  # noqa: E402,F401

import os
import sys
import time
import tempfile
from pathlib import Path

# Vendor the Sana PR branch as a sibling directory. We bundle the
# feat/sana-wm branch into ./Sana when deploying to a Space.
ROOT = Path(__file__).resolve().parent
SANA_DIR = ROOT / "Sana"
# Prepend (index 0) so the vendored checkout is searched before anything else
# on ``sys.path``; the membership test keeps the entry from being added twice.
if str(SANA_DIR) not in sys.path:
    sys.path.insert(0, str(SANA_DIR))

# Must be set before any Sana / xformers import (see inference_sana_wm.py).
# ``setdefault`` leaves an explicitly exported value untouched.
os.environ.setdefault("DISABLE_XFORMERS", "1")

# mmcv 1.7.2's setup.py imports ``pkg_resources`` (removed in setuptools 80+),
# so it cannot be installed via Spaces' isolated pip build env. We follow the
# Sana ``environment_setup.sh`` recipe and install it here on first cold-start:
# pin setuptools<80, then build mmcv with ``--no-build-isolation`` so the
# build sees the env's setuptools. Subsequent imports are no-ops.
def _ensure_mmcv() -> None:
    """Make ``mmcv`` importable, installing ``mmcv==1.7.2`` on first cold-start.

    Returns immediately when ``import mmcv`` already succeeds. Otherwise pins
    ``setuptools<80`` (plus ``wheel``) and installs mmcv with
    ``--no-build-isolation`` through the running interpreter's pip.

    Raises:
        subprocess.CalledProcessError: if either pip invocation fails.
        ImportError: if mmcv still cannot be imported after installation.
    """
    try:
        import mmcv  # noqa: F401

        return
    except ImportError:
        pass

    import importlib
    import subprocess
    import sys as _sys

    print("[startup] mmcv not found — installing with --no-build-isolation...", flush=True)
    subprocess.check_call(
        [_sys.executable, "-m", "pip", "install", "--quiet", "setuptools<80", "wheel"]
    )
    subprocess.check_call(
        [
            _sys.executable,
            "-m",
            "pip",
            "install",
            "--quiet",
            "--no-build-isolation",
            "mmcv==1.7.2",
        ]
    )
    # The package was installed while this interpreter was already running;
    # the path finders cache directory listings, so they must be invalidated
    # for the new distribution to be guaranteed visible to ``import``.
    importlib.invalidate_caches()
    import mmcv  # noqa: F401

    print("[startup] mmcv installed.", flush=True)


_ensure_mmcv()


# flash-attn matches Sana's ``environment_setup.sh`` pin (``flash-attn>=2.7.0``).
# ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` skips the nvcc kernel build, so the
# wheel installs in seconds — but ``flash_attn/__init__.py`` then unconditionally
# does ``import flash_attn_2_cuda``, which doesn't exist. We pre-stub that
# CUDA-extension module in ``sys.modules`` so the import succeeds; SANA-WM uses
# Triton GDN paths and never reaches the actual CUDA kernels.
def _ensure_flash_attn() -> None:
    """Make ``flash_attn`` importable without building its CUDA extension.

    Registers an empty ``flash_attn_2_cuda`` stub module (only if no module of
    that name is registered yet), then tries ``import flash_attn``. When that
    fails, installs ``flash-attn>=2.7.0`` with
    ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` in the subprocess environment and
    imports it again.

    Raises:
        subprocess.CalledProcessError: if the pip invocation fails.
        ImportError: if flash_attn still cannot be imported after installation.
    """
    import types as _types
    import sys as _sys

    # Keep an already-registered module; only fill the gap with a stub.
    _sys.modules.setdefault("flash_attn_2_cuda", _types.ModuleType("flash_attn_2_cuda"))
    try:
        import flash_attn  # noqa: F401

        return
    except ImportError:
        pass

    import importlib
    import subprocess

    print("[startup] flash-attn not found — installing (CUDA build skipped)...", flush=True)
    # Copy the environment so the build flag only affects the pip subprocess.
    env = dict(os.environ)
    env["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"
    subprocess.check_call(
        [
            _sys.executable,
            "-m",
            "pip",
            "install",
            "--quiet",
            "--no-build-isolation",
            "flash-attn>=2.7.0",
        ],
        env=env,
    )
    # Drop any leftover entry from the failed attempt, and refresh the finder
    # caches so the freshly installed package is discovered.
    _sys.modules.pop("flash_attn", None)
    importlib.invalidate_caches()
    import flash_attn  # noqa: F401

    print("[startup] flash-attn installed.", flush=True)


_ensure_flash_attn()

import gradio as gr
import imageio.v3 as iio
import numpy as np
import torch
from PIL import Image

# Sana-WM imports — registered by importing nets, then the pipeline module.
import diffusion.model.nets  # noqa: F401
from inference_video_scripts.inference_sana_wm import (
    DEFAULT_ROTATION_SPEED_DEG,
    DEFAULT_TRANSLATION_SPEED,
    GenerationParams,
    InferenceConfig,
    SanaWMPipeline,
    TARGET_HEIGHT,
    TARGET_WIDTH,
    action_string_to_c2w,
    estimate_intrinsics_with_pi3x,
    resize_and_center_crop,
    transform_intrinsics_for_crop,
)
from diffusion.utils.action_overlay import apply_overlay
from diffusion.utils.logger import get_root_logger
import pyrallis
from sana.tools import resolve_hf_path

HF_REPO = "Efficient-Large-Model/SANA-WM_bidirectional"
CONFIG_URI = f"hf://{HF_REPO}/config.yaml"
MODEL_URI = f"hf://{HF_REPO}/dit/sana_wm_1600m_720p.safetensors"

LOGGER = get_root_logger()

# ---------------------------------------------------------------------------
# Pipeline — built at module load time. ZeroGPU's CUDA emulation lets us place
# tensors on "cuda" without a real GPU; @spaces.GPU swaps in a real device for
# the duration of the decorated call. Lazy-loading inside the decorator is
# strongly discouraged on ZeroGPU.
# ---------------------------------------------------------------------------
print("[startup] Building SanaWMPipeline (downloads ~17 GB on first launch)...", flush=True)
_t0 = time.time()
PIPELINE_CONFIG: InferenceConfig = pyrallis.parse(
    config_class=InferenceConfig, config_path=resolve_hf_path(CONFIG_URI), args=[]
)
PIPELINE = SanaWMPipeline(
    config=PIPELINE_CONFIG,
    model_path=resolve_hf_path(MODEL_URI),
    device="cuda",
    refiner=None,
    logger=LOGGER,
)
print(f"[startup] Pipeline ready in {time.time() - _t0:.1f}s.", flush=True)

# ---------------------------------------------------------------------------
# Action-queue helpers (queue list <-> DSL string)
# ---------------------------------------------------------------------------
ALLOWED_KEYS = set("wasdijkl")


def _coerce_frames(value: object) -> int:
    """Return ``value`` as an int frame count of at least 1.

    Anything ``int()`` rejects (``None``, non-numeric strings, NaN, infinity)
    falls back to 1, the same default the queue helpers use for a segment with
    no ``frames`` entry, so one malformed segment cannot crash a UI callback.
    """
    try:
        return max(1, int(value))
    except (TypeError, ValueError, OverflowError):
        return 1


def _normalize_queue(queue) -> list[dict]:
    """Return the action queue as a list of segment dicts.

    Accepts a ``list`` (from the UI; non-dict items are dropped), a DSL string
    such as ``"w-24,jw-16"`` (from examples / serialized state), or anything
    falsy. Any other type yields an empty list.

    DSL segments that lack a ``-`` or whose frame count is not an integer
    literal are skipped; parsed frame counts are clamped to at least 1.
    """
    if not queue:
        return []
    if isinstance(queue, str):
        out = []
        for seg in queue.split(","):
            seg = seg.strip()
            if not seg or "-" not in seg:
                continue
            # Split on the LAST dash: everything before it is the key set.
            keys, frames = seg.rsplit("-", 1)
            try:
                out.append({"keys": keys.lower(), "frames": max(1, int(frames))})
            except ValueError:
                continue
        return out
    if isinstance(queue, list):
        return [s for s in queue if isinstance(s, dict)]
    return []


def queue_to_dsl(queue) -> str:
    """[{keys: 'wj', frames: 16}, ...] -> 'jw-16,...' — also accepts a DSL str.

    Per segment, only characters in ``ALLOWED_KEYS`` are kept, de-duplicated
    and sorted; a segment with no allowed key is emitted as ``none``. Frame
    counts that are missing or unusable become 1. An empty queue gives ``""``.
    """
    parts = []
    for seg in _normalize_queue(queue):
        # ``str(...)`` guards against a non-string ``keys`` payload.
        raw_keys = str(seg.get("keys") or "").lower()
        keys = "".join(sorted({c for c in raw_keys if c in ALLOWED_KEYS})) or "none"
        frames = _coerce_frames(seg.get("frames", 1))
        parts.append(f"{keys}-{frames}")
    return ",".join(parts)


def queue_total_frames(queue) -> int:
    """Return the summed frame count of every segment (each counted as >= 1)."""
    return sum(_coerce_frames(s.get("frames", 1)) for s in _normalize_queue(queue))


def _snap_num_frames(n: int, stride: int = 8, upper_bound: int | None = None) -> int:
    """Snap ``n`` to the nearest ``stride * k + 1`` (LTX-2 VAE wants 8*k + 1).

    Ties round up. When ``upper_bound`` is given the result never exceeds it
    (provided ``upper_bound >= 1``): an ``n`` above the bound is clamped first,
    and a round-up that would overshoot falls back to the lower candidate.
    The result is always at least 1.

    Raises:
        ValueError: if ``stride`` is smaller than 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    if n < 1:
        return 1
    if upper_bound is not None and n > upper_bound:
        # Honour the bound even when the request itself is above it.
        n = max(upper_bound, 1)
    offset = (n - 1) % stride
    if offset == 0:
        return n
    floor_cand = n - offset
    ceil_cand = floor_cand + stride
    # ``offset`` is the distance down to floor_cand, ``stride - offset`` the
    # distance up to ceil_cand; strictly closer to the floor wins, else ceil.
    snapped = floor_cand if offset < (stride - offset) else ceil_cand
    if upper_bound is not None and snapped > upper_bound:
        snapped = floor_cand
    return max(snapped, 1)


# ---------------------------------------------------------------------------
# tqdm <-> gr.Progress bridge
#
# Sana's samplers do ``from tqdm import tqdm`` at module load, so the symbol
# binds before our patch runs. We swap that *bound name* on the sampler module
# with a subclass that forwards progress on every update, then restore on exit.
# Avoids ``track_tqdm=True``, which breaks inside ZeroGPU subprocesses.
# ---------------------------------------------------------------------------
import contextlib
import importlib

# Modules whose module-level ``tqdm`` name is swapped while sampling runs.
# Entries that fail to import, or that expose no ``tqdm`` attribute, are
# skipped by ``_bridge_sana_tqdm``.
_SANA_TQDM_TARGETS = [
    "diffusion.scheduler.flow_euler_sampler",
    "diffusion.scheduler.dpm_solver",
    "diffusion.model.dpm_solver",
]


@contextlib.contextmanager
def _bridge_sana_tqdm(progress: "gr.Progress", *, span=(0.4, 0.9), desc_prefix="DiT"):
    """Forward Sana's tqdm ticks to the Gradio progress bar.

    While the context is active, the ``tqdm`` attribute of every importable
    module in ``_SANA_TQDM_TARGETS`` is replaced by a subclass whose
    ``update`` also calls ``progress``. The original attributes are restored
    on exit, including when the body raises.

    Args:
        progress: callable invoked as ``progress(fraction, desc=...)``.
        span: ``(lo, hi)`` pair; the tqdm fraction ``n / total`` (clamped to
            0..1) is mapped linearly onto this interval.
        desc_prefix: text placed before ``step n/total`` in the description.
    """
    import tqdm as _tqdm_mod

    lo, hi = span

    class _BridgedTqdm(_tqdm_mod.tqdm):
        def update(self, n=1):
            ret = super().update(n)
            try:
                # Bars without a (truthy) total cannot be mapped to a fraction.
                if self.total:
                    frac = max(0.0, min(1.0, self.n / self.total))
                    progress(lo + (hi - lo) * frac, desc=f"{desc_prefix} step {self.n}/{self.total}")
            except Exception:
                # Swallow any error from the progress callback so the tqdm
                # update itself still returns normally.
                pass
            return ret

    # module name -> original ``tqdm`` attribute, for restoration on exit.
    saved: dict[str, object] = {}
    for mod_name in _SANA_TQDM_TARGETS:
        try:
            mod = importlib.import_module(mod_name)
        except Exception:
            continue
        if hasattr(mod, "tqdm"):
            saved[mod_name] = mod.tqdm
            mod.tqdm = _BridgedTqdm
    try:
        yield
    finally:
        for mod_name, original in saved.items():
            try:
                importlib.import_module(mod_name).tqdm = original
            except Exception:
                pass


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
@spaces.GPU(duration=210)
def infer(
    image: Image.Image,
    prompt: str,
    queue_state,
    num_frames: int = 161,
    fps: int = 16,
    steps: int = 40,
    cfg_scale: float = 5.0,
    translation_speed: float = DEFAULT_TRANSLATION_SPEED,
    rotation_speed_deg: float = DEFAULT_ROTATION_SPEED_DEG,
    seed: int = 42,
    show_overlay: bool = True,
    progress: gr.Progress = gr.Progress(),
):
    """Generate one clip from a first frame, a prompt and an action queue.

    Steps, in order: validate the inputs; turn ``queue_state`` into a DSL
    string with ``queue_to_dsl``; roll the DSL out with
    ``action_string_to_c2w``; pick a frame count of the form 8k+1 that does
    not exceed the rolled-out length; crop the image; estimate intrinsics;
    run ``PIPELINE.generate``; optionally composite the action overlay; write
    the result as ``sana_wm.mp4`` into a fresh temporary directory.

    Args:
        image: first frame; ``None`` is rejected.
        prompt: text prompt; empty / whitespace-only is rejected.
        queue_state: whatever ``queue_to_dsl`` accepts (falsy is treated as an
            empty queue).
        num_frames: requested frame count, capped at the rolled-out length and
            snapped via ``_snap_num_frames``.
        fps, steps, cfg_scale, seed: cast and passed into ``GenerationParams``
            (``steps`` as its ``step`` field); ``fps`` is also used when
            writing the video.
        translation_speed, rotation_speed_deg: forwarded to
            ``action_string_to_c2w``.
        show_overlay: when true, ``apply_overlay`` is applied to the frames.
        progress: Gradio progress callback.

    Returns:
        ``(path_to_mp4, dsl)`` — the written file path as ``str`` and the DSL
        string derived from the queue.

    Raises:
        gr.Error: missing image, blank prompt, empty queue, or fewer than 9
            frames after snapping.
    """
    if image is None:
        raise gr.Error("Please upload a first-frame image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    dsl = queue_to_dsl(queue_state or [])
    if not dsl:
        raise gr.Error("Action queue is empty — add at least one segment.")

    progress(0.05, desc="Rolling out trajectory")
    c2w_full = action_string_to_c2w(
        dsl,
        translation_speed=translation_speed,
        rotation_speed_deg=rotation_speed_deg,
    )
    # The first axis of the rollout is used as the number of available frames.
    upper = c2w_full.shape[0]
    snapped = _snap_num_frames(min(int(num_frames), upper), stride=8, upper_bound=upper)
    if snapped < 9:
        raise gr.Error(
            f"Need at least 9 frames after LTX-2 snapping (8k+1); your queue rolled out {upper} frames."
        )
    c2w = c2w_full[:snapped]

    progress(0.1, desc="Preparing image")
    cropped, src_size, resized_size, crop_offset = resize_and_center_crop(image)

    progress(0.15, desc="Estimating intrinsics (Pi3X)")
    device = torch.device("cuda")
    intr_one = estimate_intrinsics_with_pi3x(image, device, LOGGER)
    # Repeat the single estimate for every frame; ``.copy()`` turns the
    # read-only broadcast view into a real (snapped, 4) array.
    intr_src = np.broadcast_to(intr_one, (snapped, 4)).copy()
    intrinsics_vec4 = transform_intrinsics_for_crop(intr_src, src_size, resized_size, crop_offset)

    params = GenerationParams(
        num_frames=snapped,
        fps=int(fps),
        step=int(steps),
        cfg_scale=float(cfg_scale),
        seed=int(seed),
        sampling_algo="flow_euler_ltx",
    )

    progress(0.4, desc=f"Sampling {snapped} frames in {steps} steps")
    t0 = time.time()
    # Sampler tqdm ticks are mapped onto the 0.4–0.9 part of the progress bar.
    with _bridge_sana_tqdm(progress, span=(0.4, 0.9), desc_prefix="DiT"):
        out = PIPELINE.generate(cropped, prompt.strip(), c2w, intrinsics_vec4, params)
    LOGGER.info(f"Sampling done in {time.time() - t0:.1f}s")

    video_hwc = out["video"]
    if show_overlay:
        progress(0.92, desc="Compositing action overlay")
        video_hwc = apply_overlay(video_hwc, out["c2w"])

    # NOTE(review): each call creates a new temp directory that this function
    # never removes — confirm cleanup is handled by the hosting environment.
    out_path = Path(tempfile.mkdtemp()) / "sana_wm.mp4"
    iio.imwrite(out_path, video_hwc, fps=params.fps)
    return str(out_path), dsl


# ---------------------------------------------------------------------------
# ActionQueue: custom gr.HTML component
# - 8 directional buttons (WASD = translate, IJKL = rotate) + "idle"
# - Frames-per-tap input
# - Removable queue chips
# - Three.js top-down trajectory preview that mirrors action_string_to_c2w
# ---------------------------------------------------------------------------
# NOTE(review): only three label texts are visible in this literal, yet
# JS_ON_LOAD queries elements such as ``#aq-frames``, ``#aq-back``,
# ``#aq-clear``, ``#aq-queue``, ``#aq-total`` and ``#aq-dsl`` — the element
# markup may have been lost in this view; confirm against the deployed file.
HTML_TEMPLATE = """
Translate (WASD)
Rotate (IJKL)
Queue
"""

# Styles for the ``aq-*`` classes used by the component.
CSS_TEMPLATE = """ .aq-wrap { display:flex; flex-direction:column; gap:10px; font-size:13px; } .aq-pads { display:flex; gap:14px; } .aq-pad { flex:1; background:#161b22; border:1px solid #30363d; border-radius:10px; padding:10px; } .aq-pad-label { color:#8b949e; font-size:11px; text-transform:uppercase; letter-spacing:0.05em; margin-bottom:6px; text-align:center; } .aq-grid { display:grid; grid-template-columns: repeat(3, 1fr); gap:6px; } .aq-btn { background:#21262d; color:#e6edf3; border:1px solid #30363d; border-radius:8px; padding:10px 0; cursor:pointer; font-weight:600; font-family:inherit; font-size:13px; transition:transform .08s ease, background .12s ease; } .aq-btn:hover { background:#2d333b; } .aq-btn:active { transform:scale(0.96); } .aq-tr { background:#1f6feb22; border-color:#1f6feb55; color:#79b8ff; } .aq-tr:hover { background:#1f6feb44; } .aq-ro { background:#a371f722; border-color:#a371f755; color:#c9a4ff; } .aq-ro:hover { background:#a371f744; } .aq-idle { background:#30363d; color:#8b949e; } .aq-secondary { background:#30363d; } .aq-danger { background:#da363322; border-color:#da363355; color:#ff7b72; } .aq-row { display:flex; align-items:center; gap:10px; } .aq-row label { color:#8b949e; font-size:12px; display:flex; align-items:center; gap:6px; } .aq-row input[type=number] { background:#0d1117; color:#e6edf3; border:1px solid #30363d; border-radius:6px; padding:4px 6px; width:70px; font-family:inherit; } .aq-spacer { flex:1; } .aq-queue-box { background:#0d1117; border:1px solid #30363d; border-radius:10px; padding:10px; } .aq-queue-header { display:flex; justify-content:space-between; color:#8b949e; font-size:11px; text-transform:uppercase; margin-bottom:6px; } .aq-queue { display:flex; flex-wrap:wrap; gap:6px; min-height:32px; } .aq-chip { display:inline-flex; align-items:center; gap:6px; padding:4px 8px; border-radius:14px; background:#21262d; border:1px solid #30363d; color:#e6edf3; font-family:ui-monospace,monospace; font-size:12px; }
.aq-chip button { background:transparent; border:none; color:#8b949e; cursor:pointer; font-size:14px; line-height:1; padding:0; } .aq-chip button:hover { color:#ff7b72; } .aq-empty { color:#6e7681; font-style:italic; padding:4px 0; } .aq-dsl { margin-top:8px; padding:6px 8px; background:#010409; border-radius:6px; color:#7ee787; font-family:ui-monospace,monospace; font-size:12px; word-break:break-all; min-height:18px; } """

# Client-side behaviour of the component. The script keeps a local ``queue``
# array, caps the summed frames at MAX_FRAMES (160), merges a tap into the
# previous segment when the keys match, and pushes every change back through
# ``props.value`` and ``trigger('change', ...)``. Incoming values may be an
# array or a DSL string (parsed by ``parseDsl``).
# NOTE(review): the script contains ``//`` line comments, so it depends on its
# line breaks — keep the literal's newlines intact when editing.
JS_ON_LOAD = r""" (() => { const root = element; const framesInput = root.querySelector('#aq-frames'); const backBtn = root.querySelector('#aq-back'); const clearBtn = root.querySelector('#aq-clear'); const queueEl = root.querySelector('#aq-queue'); const totalEl = root.querySelector('#aq-total'); const dslEl = root.querySelector('#aq-dsl'); // ----- queue state ------------------------------------------------------ let queue = Array.isArray(props.value) ? props.value.slice() : []; const MAX_FRAMES = 160; // num_frames slider max is 161 = MAX_FRAMES + 1 function render() { queueEl.innerHTML = ''; if (queue.length === 0) { const e = document.createElement('span'); e.className = 'aq-empty'; e.textContent = 'No segments yet — tap a button above.'; queueEl.appendChild(e); } else { queue.forEach((seg, i) => { const chip = document.createElement('span'); chip.className = 'aq-chip'; chip.textContent = `${seg.keys || 'none'}-${seg.frames}`; const x = document.createElement('button'); x.textContent = '×'; x.title = 'Remove'; x.addEventListener('click', (ev) => { ev.stopPropagation(); queue.splice(i, 1); commit(); }); chip.appendChild(x); queueEl.appendChild(chip); }); } const total = queue.reduce((s, x) => s + (x.frames | 0), 0); totalEl.textContent = total ? `${queue.length} segment${queue.length > 1 ? 's' : ''} · ${total}/${MAX_FRAMES} frames` : `0/${MAX_FRAMES} frames`; totalEl.style.color = total >= MAX_FRAMES ? '#ff7b72' : ''; dslEl.textContent = queue.length ?
queue.map(s => `${(s.keys || 'none')}-${s.frames}`).join(',') : '(empty)'; } function commit() { render(); props.value = queue.slice(); trigger('change', queue.slice()); } function appendSegment(keys) { const requested = Math.max(1, parseInt(framesInput.value || '16', 10)); const current = queue.reduce((s, x) => s + (x.frames | 0), 0); const remaining = MAX_FRAMES - current; if (remaining <= 0) return; // queue full const frames = Math.min(requested, remaining); // Coalesce consecutive identical segments. const last = queue[queue.length - 1]; if (last && last.keys === keys) { last.frames += frames; } else { queue.push({ keys, frames }); } commit(); } root.querySelectorAll('.aq-btn[data-k]').forEach(btn => { btn.addEventListener('click', () => appendSegment(btn.dataset.k)); }); backBtn.addEventListener('click', () => { queue.pop(); commit(); }); clearBtn.addEventListener('click', () => { queue = []; commit(); }); // Keyboard shortcuts (when focus is anywhere in the queue block). root.tabIndex = 0; root.addEventListener('keydown', (e) => { const k = e.key.toLowerCase(); if ('wasdijkl'.includes(k)) { appendSegment(k); e.preventDefault(); } else if (k === ' ') { appendSegment('none'); e.preventDefault(); } else if (k === 'backspace') { queue.pop(); commit(); e.preventDefault(); } }); // React to external value changes (Examples click, programmatic updates).
function parseDsl(s) { return s.split(',').map(seg => { const idx = seg.lastIndexOf('-'); if (idx < 0) return null; const keys = seg.slice(0, idx).toLowerCase().trim(); const frames = parseInt(seg.slice(idx + 1), 10); if (!keys || !Number.isFinite(frames) || frames < 1) return null; return { keys, frames }; }).filter(Boolean); } watch('value', () => { let incoming; if (Array.isArray(props.value)) { incoming = props.value; } else if (typeof props.value === 'string') { incoming = parseDsl(props.value); } else { incoming = []; } // Skip if the incoming value matches our local state — avoids feedback loops // from our own commit() calls. if (JSON.stringify(incoming) === JSON.stringify(queue)) return; queue = incoming.map(s => ({ keys: (s.keys || 'none'), frames: Math.max(1, s.frames | 0) })); render(); }); render(); })(); """


class ActionQueue(gr.HTML):
    """WASD/IJKL action-queue input. value = list[{keys: str, frames: int}].

    A ``gr.HTML`` subclass wired to the module's ``HTML_TEMPLATE``,
    ``CSS_TEMPLATE`` and ``JS_ON_LOAD`` strings.
    """

    def __init__(self, value=None, **kwargs):
        """Create the component; ``value=None`` is replaced by an empty list.

        Remaining keyword arguments are forwarded unchanged to ``gr.HTML``.
        """
        if value is None:
            value = []
        super().__init__(
            value=value,
            html_template=HTML_TEMPLATE,
            css_template=CSS_TEMPLATE,
            js_on_load=JS_ON_LOAD,
            **kwargs,
        )

    def api_info(self):
        """Return the JSON-schema-style description of the component value.

        The value is an array of objects, each requiring a string ``keys``
        and an integer ``frames`` with minimum 1.
        """
        return {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "keys": {"type": "string"},
                    "frames": {"type": "integer", "minimum": 1},
                },
                "required": ["keys", "frames"],
            },
        }


# ---------------------------------------------------------------------------
# Examples
# ---------------------------------------------------------------------------
EXAMPLE_DIR = SANA_DIR / "asset" / "sana_wm"


def _example_queue(dsl: str) -> list[dict]:
    """Parse a DSL string such as ``"w-24,jw-16"`` into segment dicts.

    Strict: a segment without ``-`` or with a non-integer frame count raises
    ``ValueError``. Not referenced elsewhere in the visible part of the file.
    """
    out = []
    for seg in dsl.split(","):
        # Split on the last dash so the frame count is always the tail.
        keys, frames = seg.rsplit("-", 1)
        out.append({"keys": keys.lower(), "frames": int(frames)})
    return out


def _load_example_prompt(name: str) -> str:
    """Return the stripped text of ``EXAMPLE_DIR/<name>.txt``, or ``""`` if absent.

    Undecodable bytes are replaced rather than raising.
    """
    p = EXAMPLE_DIR / f"{name}.txt"
    return p.read_text(encoding="utf-8", errors="replace").strip() if p.exists() else ""


# Rows of [image path, prompt, DSL string]; a row is added only when the
# example image exists on disk, so a missing asset folder leaves the list empty.
EXAMPLES = []
for stem, dsl in [
    ("demo_0", "w-24,jw-24,w-24"),
    ("demo_1", "w-32,iw-16,w-24"),
    ("demo_2", "w-16,lw-16,w-16,jw-16,w-16"),
]:
    img_path = EXAMPLE_DIR / f"{stem}.png"
    if img_path.exists():
        EXAMPLES.append([str(img_path), _load_example_prompt(stem), dsl])


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
DESCRIPTION = """ # 🌍 SANA-WM — Camera-Controlled World Model Image-to-video generation with 6-DoF camera control using [`Efficient-Large-Model/SANA-WM_bidirectional`](https://huggingface.co/Efficient-Large-Model/SANA-WM_bidirectional) and the [NVlabs/Sana](https://github.com/NVlabs/Sana) Stage-1 pipeline. """

with gr.Blocks(theme=gr.themes.Soft(), title="SANA-WM Demo") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: inputs, action queue, and advanced settings.
        with gr.Column(scale=1):
            image_in = gr.Image(label="First frame", type="pil", height=300)
            prompt_in = gr.Textbox(
                label="Prompt",
                lines=4,
                placeholder="A first-person view from a strictly stationary observation point...",
            )
            gr.Markdown("### 🎮 Camera action queue")
            action_queue = ActionQueue(value=[{"keys": "w", "frames": 24}])
            dsl_view = gr.Textbox(label="Resulting DSL", interactive=False)
            run_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

            with gr.Accordion("Advanced", open=False):
                num_frames = gr.Slider(
                    9,
                    161,
                    value=41,
                    step=8,
                    label="num_frames (auto-tracks queue length, LTX-2 VAE → 8k+1)",
                    info="Adjust to override the auto-sync; will be snapped to 8k+1.",
                )
                fps = gr.Slider(8, 24, value=16, step=1, label="fps")
                steps = gr.Slider(10, 80, value=40, step=2, label="DiT sampling steps")
                cfg_scale = gr.Slider(1.0, 10.0, value=5.0, step=0.1, label="cfg_scale")
                translation_speed = gr.Slider(
                    0.01, 0.15, value=DEFAULT_TRANSLATION_SPEED, step=0.005, label="translation_speed (per frame)"
                )
                rotation_speed_deg = gr.Slider(
                    0.2, 5.0, value=DEFAULT_ROTATION_SPEED_DEG, step=0.1, label="rotation_speed_deg (per frame)"
                )
                seed = gr.Number(value=42, precision=0, label="seed")
                show_overlay = gr.Checkbox(value=True, label="Composite Genie-style action overlay on the output")

        # Right column: the generated video.
        with gr.Column(scale=1):
            video_out = gr.Video(label="Output (704×1280)", height=520, autoplay=True)
            gr.Markdown(
                "_The output is the Sana VAE decode of Stage-1 latents (no refiner). "
                "For peak quality use the full pipeline with `--no_refiner` disabled offline._"
            )

    if EXAMPLES:
        # action_queue accepts DSL strings (normalized in Python and parsed in
        # the component's watch hook), so we can feed the DSL string straight
        # in — gives a readable table column. cache_examples=lazy runs infer()
        # on first click per row and caches the output mp4.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[image_in, prompt_in, action_queue],
            outputs=[video_out, dsl_view],
            fn=infer,
            cache_examples=True,
            cache_mode="lazy",
            label="Example image + prompt + camera action queue (lazy-cached)",
        )

    # Mirror the queue's DSL into the read-only textbox.
    action_queue.change(
        fn=lambda q: queue_to_dsl(q or []),
        inputs=[action_queue],
        outputs=[dsl_view],
    )

    # Auto-sync num_frames to the queue total (snapped to 8k+1, clipped to slider).
    def _sync_num_frames(q: list) -> int:
        """Return the slider value matching the queue: 8k+1, within 9..161.

        ``total + 1`` is used both as the target and as the upper bound passed
        to ``_snap_num_frames``; an empty queue maps to the slider minimum 9.
        """
        total = queue_total_frames(q or [])
        if total < 1:
            return 9
        snapped = _snap_num_frames(total + 1, stride=8, upper_bound=total + 1)
        return max(9, min(161, snapped))

    action_queue.change(
        fn=_sync_num_frames,
        inputs=[action_queue],
        outputs=[num_frames],
    )

    # Input order must match the positional parameters of ``infer``.
    run_btn.click(
        fn=infer,
        inputs=[
            image_in,
            prompt_in,
            action_queue,
            num_frames,
            fps,
            steps,
            cfg_scale,
            translation_speed,
            rotation_speed_deg,
            seed,
            show_overlay,
        ],
        outputs=[video_out, dsl_view],
    )


if __name__ == "__main__":
    # NOTE(review): ``head`` is an empty string in this view — confirm whether
    # extra <head> markup is expected here.
    head = ''
    demo.launch(head=head)