"""ZeroGPU Gradio demo for SANA-WM (bidirectional). Loads ``Efficient-Large-Model/SANA-WM_bidirectional`` (Stage-1 DiT + LTX-2 VAE + gemma-2-2b-it text encoder, no refiner) and lets the user roll out a camera-controlled image-to-video clip from a WASD/IJKL action queue. """ from __future__ import annotations # ``spaces`` must be imported BEFORE anything that initialises CUDA (including # torch, mmcv, etc.). mmcv pulls torch in transitively, so even our runtime # install path below would touch CUDA emulation first if ``spaces`` weren't # already imported. import spaces # noqa: E402,F401 import os import sys import time import tempfile from pathlib import Path # Vendor the Sana PR branch as a sibling directory. We bundle the # feat/sana-wm branch into ./Sana when deploying to a Space. ROOT = Path(__file__).resolve().parent SANA_DIR = ROOT / "Sana" if str(SANA_DIR) not in sys.path: sys.path.insert(0, str(SANA_DIR)) # Must be set before any Sana / xformers import (see inference_sana_wm.py). os.environ.setdefault("DISABLE_XFORMERS", "1") # mmcv 1.7.2's setup.py imports ``pkg_resources`` (removed in setuptools 80+), # so it cannot be installed via Spaces' isolated pip build env. We follow the # Sana ``environment_setup.sh`` recipe and install it here on first cold-start: # pin setuptools<80, then build mmcv with ``--no-build-isolation`` so the # build sees the env's setuptools. Subsequent imports are no-ops. 
def _ensure_mmcv() -> None:
    """Make ``mmcv`` importable, installing ``mmcv==1.7.2`` if it is missing.

    When ``import mmcv`` already works this returns immediately. Otherwise it
    runs two ``pip install`` subprocesses with the current interpreter --
    ``setuptools<80`` plus ``wheel`` first, then ``mmcv==1.7.2`` with
    ``--no-build-isolation`` -- and imports ``mmcv`` again.

    Raises:
        subprocess.CalledProcessError: If either ``pip install`` exits non-zero.
        ImportError: If ``mmcv`` still cannot be imported after the install.
    """
    try:
        import mmcv  # noqa: F401

        return
    except ImportError:
        pass

    import importlib
    import subprocess
    import sys as _sys

    print("[startup] mmcv not found — installing with --no-build-isolation...", flush=True)
    subprocess.check_call(
        [_sys.executable, "-m", "pip", "install", "--quiet", "setuptools<80", "wheel"]
    )
    subprocess.check_call(
        [
            _sys.executable,
            "-m",
            "pip",
            "install",
            "--quiet",
            "--no-build-isolation",
            "mmcv==1.7.2",
        ]
    )
    # The package was written to disk by a child process after this
    # interpreter started; drop the finders' cached directory listings so the
    # import below is guaranteed to see it.
    importlib.invalidate_caches()
    import mmcv  # noqa: F401

    print("[startup] mmcv installed.", flush=True)


_ensure_mmcv()


# flash-attn matches Sana's ``environment_setup.sh`` pin (``flash-attn>=2.7.0``).
# ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` skips the nvcc kernel build, so the
# wheel installs in seconds — but ``flash_attn/__init__.py`` then unconditionally
# does ``import flash_attn_2_cuda``, which doesn't exist. We pre-stub that
# CUDA-extension module in ``sys.modules`` so the import succeeds; SANA-WM uses
# Triton GDN paths and never reaches the actual CUDA kernels.
def _ensure_flash_attn() -> None:
    """Make ``flash_attn`` importable without building its CUDA extension.

    An empty placeholder module named ``flash_attn_2_cuda`` is registered in
    ``sys.modules`` (unless one is already there) before ``flash_attn`` is
    imported. If the import fails, ``flash-attn>=2.7.0`` is installed through
    ``pip`` with ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` set in the child
    process environment, and the import is retried.

    Raises:
        subprocess.CalledProcessError: If ``pip install`` exits non-zero.
        ImportError: If ``flash_attn`` still cannot be imported afterwards.
    """
    import types as _types
    import sys as _sys

    if "flash_attn_2_cuda" not in _sys.modules:
        _sys.modules["flash_attn_2_cuda"] = _types.ModuleType("flash_attn_2_cuda")
    try:
        import flash_attn  # noqa: F401

        return
    except ImportError:
        pass

    import importlib
    import subprocess

    print("[startup] flash-attn not found — installing (CUDA build skipped)...", flush=True)
    # Copy the environment so the skip flag only affects the pip subprocess.
    env = dict(os.environ)
    env["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"
    subprocess.check_call(
        [
            _sys.executable,
            "-m",
            "pip",
            "install",
            "--quiet",
            "--no-build-isolation",
            "flash-attn>=2.7.0",
        ],
        env=env,
    )
    # Forget any half-initialised module object left by the failed attempt and
    # refresh the finder caches so the freshly installed package is found.
    _sys.modules.pop("flash_attn", None)
    importlib.invalidate_caches()
    import flash_attn  # noqa: F401

    print("[startup] flash-attn installed.", flush=True)


_ensure_flash_attn()

import gradio as gr
import imageio.v3 as iio
import numpy as np
import torch
from PIL import Image

# Sana-WM imports — registered by importing nets, then the pipeline module.
import diffusion.model.nets  # noqa: F401
from inference_video_scripts.inference_sana_wm import (
    DEFAULT_ROTATION_SPEED_DEG,
    DEFAULT_TRANSLATION_SPEED,
    GenerationParams,
    InferenceConfig,
    SanaWMPipeline,
    TARGET_HEIGHT,
    TARGET_WIDTH,
    action_string_to_c2w,
    estimate_intrinsics_with_pi3x,
    resize_and_center_crop,
    transform_intrinsics_for_crop,
)
from diffusion.utils.action_overlay import apply_overlay
from diffusion.utils.logger import get_root_logger
import pyrallis
from sana.tools import resolve_hf_path

HF_REPO = "Efficient-Large-Model/SANA-WM_bidirectional"
CONFIG_URI = f"hf://{HF_REPO}/config.yaml"
MODEL_URI = f"hf://{HF_REPO}/dit/sana_wm_1600m_720p.safetensors"

LOGGER = get_root_logger()


# ---------------------------------------------------------------------------
# Pipeline — built at module load time. ZeroGPU's CUDA emulation lets us place
# tensors on "cuda" without a real GPU; @spaces.GPU swaps in a real device for
# the duration of the decorated call. Lazy-loading inside the decorator is
# strongly discouraged on ZeroGPU.
# ---------------------------------------------------------------------------
print("[startup] Building SanaWMPipeline (downloads ~17 GB on first launch)...", flush=True)
_t0 = time.time()
PIPELINE_CONFIG: InferenceConfig = pyrallis.parse(
    config_class=InferenceConfig, config_path=resolve_hf_path(CONFIG_URI), args=[]
)
PIPELINE = SanaWMPipeline(
    config=PIPELINE_CONFIG,
    model_path=resolve_hf_path(MODEL_URI),
    device="cuda",
    refiner=None,
    logger=LOGGER,
)
print(f"[startup] Pipeline ready in {time.time() - _t0:.1f}s.", flush=True)


# ---------------------------------------------------------------------------
# Action-queue helpers (queue list <-> DSL string)
# ---------------------------------------------------------------------------
ALLOWED_KEYS = set("wasdijkl")


def _coerce_frames(value, default: int = 1) -> int:
    """Return ``value`` as a frame count of at least 1.

    Values that ``int()`` rejects (``None``, non-numeric strings, NaN/inf
    floats) yield ``default`` instead of raising, so a single malformed queue
    segment cannot crash the request handler.
    """
    try:
        frames = int(value)
    except (TypeError, ValueError, OverflowError):
        return default
    return max(1, frames)


def _normalize_queue(queue) -> list[dict]:
    """Return the action queue as a list of segment dicts.

    Accepted inputs:

    * a list -- entries that are not dicts are dropped, dicts are returned
      as-is (not copied or validated);
    * a DSL string such as ``"wj-16,none-8"`` -- split on commas, each segment
      split on its last ``-`` into keys and a frame count; segments without a
      ``-`` or with a non-integer frame count are skipped, keys are lowercased
      and frame counts are raised to at least 1;
    * anything falsy or of another type -- yields an empty list.
    """
    if not queue:
        return []
    if isinstance(queue, str):
        out = []
        for seg in queue.split(","):
            seg = seg.strip()
            if not seg or "-" not in seg:
                continue
            keys, frames = seg.rsplit("-", 1)
            try:
                out.append({"keys": keys.lower(), "frames": max(1, int(frames))})
            except ValueError:
                continue
        return out
    if isinstance(queue, list):
        return [s for s in queue if isinstance(s, dict)]
    return []


def queue_to_dsl(queue) -> str:
    """[{keys: 'wj', frames: 16}, ...] -> 'jw-16,...' — also accepts a DSL str.

    Each segment's keys are lowercased, filtered to ``ALLOWED_KEYS``,
    de-duplicated and sorted; a segment left with no keys is written as
    ``none``. Missing or unparsable frame counts are written as 1. Returns an
    empty string for an empty queue.
    """
    segments = _normalize_queue(queue)
    if not segments:
        return ""
    parts = []
    for seg in segments:
        keys = (seg.get("keys") or "").lower().strip()
        keys = "".join(sorted({c for c in keys if c in ALLOWED_KEYS})) or "none"
        parts.append(f"{keys}-{_coerce_frames(seg.get('frames', 1))}")
    return ",".join(parts)


def queue_total_frames(queue) -> int:
    """Return the summed frame count of all segments in ``queue``.

    Uses the same per-segment frame coercion as ``queue_to_dsl`` so the total
    always agrees with the DSL string produced for the same queue.
    """
    return sum(_coerce_frames(seg.get("frames", 1)) for seg in _normalize_queue(queue))


def _snap_num_frames(n: int, stride: int = 8, upper_bound: int | None = None) -> int:
    """LTX-2 VAE wants num_frames = 8*k + 1.

    Snaps ``n`` to the nearest value of the form ``stride * k + 1`` (ties go
    up). When ``upper_bound`` is given, ``n`` is first clamped to it and the
    result is rounded down rather than up if rounding up would exceed the
    bound. The result is never smaller than 1.
    """
    if upper_bound is not None:
        # Honour the bound even when the requested count is already past it.
        n = min(n, upper_bound)
    if n < 1:
        return 1
    offset = (n - 1) % stride
    if offset == 0:
        return n
    floor_cand = n - offset
    ceil_cand = floor_cand + stride
    snapped = floor_cand if (n - floor_cand) < (ceil_cand - n) else ceil_cand
    if upper_bound is not None and snapped > upper_bound:
        snapped = floor_cand
    return max(snapped, 1)


# ---------------------------------------------------------------------------
# tqdm <-> gr.Progress bridge
#
# Sana's samplers do ``from tqdm import tqdm`` at module load, so the symbol
# binds before our patch runs. We swap that *bound name* on the sampler module
# with a subclass that forwards progress on every update, then restore on exit.
# Avoids ``track_tqdm=True``, which breaks inside ZeroGPU subprocesses.
# --------------------------------------------------------------------------- import contextlib import importlib _SANA_TQDM_TARGETS = [ "diffusion.scheduler.flow_euler_sampler", "diffusion.scheduler.dpm_solver", "diffusion.model.dpm_solver", ] @contextlib.contextmanager def _bridge_sana_tqdm(progress: "gr.Progress", *, span=(0.4, 0.9), desc_prefix="DiT"): """Forward Sana's tqdm ticks to the Gradio progress bar.""" import tqdm as _tqdm_mod lo, hi = span class _BridgedTqdm(_tqdm_mod.tqdm): def update(self, n=1): ret = super().update(n) try: if self.total: frac = max(0.0, min(1.0, self.n / self.total)) progress(lo + (hi - lo) * frac, desc=f"{desc_prefix} step {self.n}/{self.total}") except Exception: pass return ret saved: dict[str, object] = {} for mod_name in _SANA_TQDM_TARGETS: try: mod = importlib.import_module(mod_name) except Exception: continue if hasattr(mod, "tqdm"): saved[mod_name] = mod.tqdm mod.tqdm = _BridgedTqdm try: yield finally: for mod_name, original in saved.items(): try: importlib.import_module(mod_name).tqdm = original except Exception: pass # --------------------------------------------------------------------------- # Inference # --------------------------------------------------------------------------- @spaces.GPU(duration=210) def infer( image: Image.Image, prompt: str, queue_state, num_frames: int = 161, fps: int = 16, steps: int = 40, cfg_scale: float = 5.0, translation_speed: float = DEFAULT_TRANSLATION_SPEED, rotation_speed_deg: float = DEFAULT_ROTATION_SPEED_DEG, seed: int = 42, show_overlay: bool = True, progress: gr.Progress = gr.Progress(), ): if image is None: raise gr.Error("Please upload a first-frame image.") if not prompt or not prompt.strip(): raise gr.Error("Please enter a prompt.") dsl = queue_to_dsl(queue_state or []) if not dsl: raise gr.Error("Action queue is empty — add at least one segment.") progress(0.05, desc="Rolling out trajectory") c2w_full = action_string_to_c2w( dsl, translation_speed=translation_speed, 
rotation_speed_deg=rotation_speed_deg, ) upper = c2w_full.shape[0] snapped = _snap_num_frames(min(int(num_frames), upper), stride=8, upper_bound=upper) if snapped < 9: raise gr.Error( f"Need at least 9 frames after LTX-2 snapping (8k+1); your queue rolled out {upper} frames." ) c2w = c2w_full[:snapped] progress(0.1, desc="Preparing image") cropped, src_size, resized_size, crop_offset = resize_and_center_crop(image) progress(0.15, desc="Estimating intrinsics (Pi3X)") device = torch.device("cuda") intr_one = estimate_intrinsics_with_pi3x(image, device, LOGGER) intr_src = np.broadcast_to(intr_one, (snapped, 4)).copy() intrinsics_vec4 = transform_intrinsics_for_crop(intr_src, src_size, resized_size, crop_offset) params = GenerationParams( num_frames=snapped, fps=int(fps), step=int(steps), cfg_scale=float(cfg_scale), seed=int(seed), sampling_algo="flow_euler_ltx", ) progress(0.4, desc=f"Sampling {snapped} frames in {steps} steps") t0 = time.time() with _bridge_sana_tqdm(progress, span=(0.4, 0.9), desc_prefix="DiT"): out = PIPELINE.generate(cropped, prompt.strip(), c2w, intrinsics_vec4, params) LOGGER.info(f"Sampling done in {time.time() - t0:.1f}s") video_hwc = out["video"] if show_overlay: progress(0.92, desc="Compositing action overlay") video_hwc = apply_overlay(video_hwc, out["c2w"]) out_path = Path(tempfile.mkdtemp()) / "sana_wm.mp4" iio.imwrite(out_path, video_hwc, fps=params.fps) return str(out_path), dsl # --------------------------------------------------------------------------- # ActionQueue: custom gr.HTML component # - 8 directional buttons (WASD = translate, IJKL = rotate) + "idle" # - Frames-per-tap input # - Removable queue chips # - Three.js top-down trajectory preview that mirrors action_string_to_c2w # --------------------------------------------------------------------------- HTML_TEMPLATE = """