Spaces:
Running on Zero
Running on Zero
File size: 6,036 Bytes
12ab2ca 30a5c8e 12ab2ca 8bfce29 12ab2ca 30a5c8e 12ab2ca 30a5c8e 12ab2ca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 | """
Dramabox — Resemble AI directable speech engine.
Single-Space tool: generates a 48 kHz WAV "performance" from a scene prompt
(quoted dialogue + stage directions) and an optional voice reference. Mirrors
the official ResembleAI/Dramabox Space's on_generate(): same parameter order,
same defaults, same model invocation.
This module only runs on the videovoice-dramabox Space, which must vendor the
Dramabox `src/` directory (inference_server.py + model_downloader.py) and the
requirements-dramabox.txt deps. On any other Space the lazy import below
raises a clean RuntimeError rather than crashing app startup.
The module loads the TTSServer once on first request (warm-load pattern from
the upstream Space) and reuses it across calls.
"""
from __future__ import annotations
import logging
import os
import threading
import time
from pathlib import Path
import spaces
# Backend env knobs — kept compatible with the upstream Space.
# LTX_DTYPE is read once at import time (default "bf16") and later passed as
# the `dtype` argument when the TTSServer is constructed.
_LTX_DTYPE = os.environ.get("LTX_DTYPE", "bf16")
# Module-level warm load, guarded by a lock so a flurry of concurrent first
# requests only triggers one load. Subsequent calls are ~2.5s on warm GPU.
_tts_lock = threading.Lock()
_tts_server = None # stays None until _ensure_server() loads the model on the first request
# Module logger under the fixed name "tools_api.dramabox".
logger = logging.getLogger("tools_api.dramabox")
def _import_dramabox():
    """Import the vendored Dramabox entry points.

    Returns:
        The ``(TTSServer, get_all_paths)`` pair imported from the vendored
        ``inference_server`` and ``model_downloader`` modules.

    Raises:
        RuntimeError: chained from the underlying ImportError when either
            module (or anything it imports) cannot be imported.
    """
    try:
        import sys

        # The vendored copy of ResembleAI/Dramabox's `src/` is put on
        # sys.path here so app.py does not have to do the insert itself.
        # Upstream layout: src/ holds inference_server.py, which then puts
        # the sibling ltx2/ on sys.path on its own.
        src_dir = Path(__file__).parent.parent / "dramabox_src" / "src"
        src_entry = str(src_dir)
        if src_dir.exists() and src_entry not in sys.path:
            sys.path.insert(0, src_entry)
        from inference_server import TTSServer  # type: ignore[import-not-found]
        from model_downloader import get_all_paths  # type: ignore[import-not-found]
    except ImportError as exc:
        raise RuntimeError(
            "Dramabox is not installed on this Space. Vendor "
            "ResembleAI/Dramabox's src/ directory at "
            "VideoVoice-be/dramabox_src/ and install requirements-dramabox.txt."
        ) from exc
    return TTSServer, get_all_paths


def _ensure_server():
    """Return the shared TTSServer, building it on the first call.

    The instance is cached in the module-level ``_tts_server``; later calls
    return it without reloading. Construction happens while holding
    ``_tts_lock``.

    Raises:
        RuntimeError: when the vendored Dramabox modules cannot be imported.
    """
    global _tts_server
    if _tts_server is None:
        with _tts_lock:
            # Check again now that the lock is held: another caller may have
            # finished the load while this one was waiting.
            if _tts_server is None:
                server_cls, resolve_paths = _import_dramabox()
                logger.info("Fetching Dramabox checkpoints (cached after first run)...")
                ckpt = resolve_paths()
                logger.info("Loading Dramabox warm server (Gemma + DiT + VAE + Decoder)...")
                _tts_server = server_cls(
                    checkpoint=ckpt["transformer"],
                    full_checkpoint=ckpt["audio_components"],
                    gemma_root=ckpt["gemma_root"],
                    device="cuda",
                    dtype=_LTX_DTYPE,
                    # torch.compile breaks under ZeroGPU's brief GPU windows
                    compile_model=False,
                    # unsloth Gemma is pre-quantized
                    bnb_4bit=True,
                )
                logger.info("Dramabox TTSServer ready.")
    return _tts_server
@spaces.GPU(duration=60)
def _generate_scene_gpu(
    *,
    prompt: str,
    out_dir: Path,
    audio_ref: Path | None,
    cfg: float,
    stg: float,
    dur_mult: float,
    gen_dur: float,
    ref_dur: float,
    seed: int,
) -> dict:
    """ZeroGPU entry point: hand every argument unchanged to ``_generate_impl``.

    The function is decorated with ``spaces.GPU`` at module top level so that
    HF detects Dramabox GPU usage at startup. It adds no logic of its own and
    returns whatever ``_generate_impl`` returns.
    """
    forwarded = {
        "prompt": prompt,
        "out_dir": out_dir,
        "audio_ref": audio_ref,
        "cfg": cfg,
        "stg": stg,
        "dur_mult": dur_mult,
        "gen_dur": gen_dur,
        "ref_dur": ref_dur,
        "seed": seed,
    }
    return _generate_impl(**forwarded)
def generate_scene(
    *,
    prompt: str,
    out_dir: Path,
    audio_ref: Path | None = None,
    cfg: float = 2.5,
    stg: float = 1.5,
    dur_mult: float = 1.1,
    gen_dur: float = 0.0,
    ref_dur: float = 10.0,
    seed: int = 42,
) -> dict:
    """Validate ``prompt`` and synthesize it into a WAV file under ``out_dir``.

    All arguments are keyword-only. The prompt is stripped of surrounding
    whitespace; everything else is forwarded untouched to the GPU-decorated
    wrapper.

    Args:
        prompt: Scene text to perform. ``None`` is treated like an empty string.
        out_dir: Directory that receives the generated file.
        audio_ref: Optional path to a voice reference.
        cfg, stg, dur_mult, gen_dur, ref_dur, seed: Generation settings,
            passed through as given.

    Returns:
        A dict of the form::

            {
                "filename": "dramabox_<unix-epoch-milliseconds>.wav",
                "elapsed": <seconds>,
                "settings": {...echo of inputs used...},
            }

    Raises:
        ValueError: if the prompt is missing or contains only whitespace.
    """
    scene_text = prompt.strip() if prompt else ""
    if not scene_text:
        raise ValueError("Prompt is empty.")

    request = {
        "prompt": scene_text,
        "out_dir": out_dir,
        "audio_ref": audio_ref,
        "cfg": cfg,
        "stg": stg,
        "dur_mult": dur_mult,
        "gen_dur": gen_dur,
        "ref_dur": ref_dur,
        "seed": seed,
    }
    return _generate_scene_gpu(**request)
def _generate_impl(
    *,
    prompt: str,
    out_dir: Path,
    audio_ref: Path | None,
    cfg: float,
    stg: float,
    dur_mult: float,
    gen_dur: float,
    ref_dur: float,
    seed: int,
) -> dict:
    """Run one generation on the shared server and write the WAV to ``out_dir``.

    Args:
        prompt: Scene text handed to the model as-is.
        out_dir: Output directory; created (with parents) if missing.
        audio_ref: Optional voice-reference path. If it is given but does not
            exist on disk, a warning is logged and generation proceeds
            without a reference.
        cfg: Forwarded as ``cfg_scale`` after ``float()`` conversion.
        stg: Forwarded as ``stg_scale`` after ``float()`` conversion.
        dur_mult: Forwarded as ``duration_multiplier`` after ``float()``.
        gen_dur: Forwarded as ``gen_duration`` after ``float()``.
        ref_dur: Forwarded as ``ref_duration`` after ``float()``.
        seed: Forwarded as ``seed`` after ``int()`` conversion.

    Returns:
        ``{"filename": <output file name>, "elapsed": <seconds spent in the
        model call>, "settings": {...}}`` where ``settings`` echoes the inputs
        as received plus ``had_voice_ref``.

    Raises:
        RuntimeError: propagated from ``_ensure_server()`` when Dramabox is
            not available.
    """
    tts = _ensure_server()
    out_dir.mkdir(parents=True, exist_ok=True)
    # File name is keyed on the wall-clock time in milliseconds.
    output = out_dir / f"dramabox_{int(time.time() * 1000)}.wav"

    ref_path: str | None = None
    if audio_ref is not None:
        if Path(audio_ref).exists():
            ref_path = str(audio_ref)
        else:
            # Keep the best-effort fallback, but make the dropped reference
            # visible instead of ignoring it silently.
            logger.warning(
                "Dramabox voice reference not found, generating without it: %s",
                audio_ref,
            )

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments the way time.time() differences can.
    t0 = time.perf_counter()
    tts.generate_to_file(
        prompt=prompt,
        output=str(output),
        voice_ref=ref_path,
        cfg_scale=float(cfg),
        stg_scale=float(stg),
        duration_multiplier=float(dur_mult),
        seed=int(seed),
        gen_duration=float(gen_dur),
        ref_duration=float(ref_dur),
    )
    elapsed = time.perf_counter() - t0
    logger.info("Dramabox generated in %.2fs -> %s", elapsed, output)

    return {
        "filename": output.name,
        "elapsed": elapsed,
        "settings": {
            "cfg": cfg,
            "stg": stg,
            "dur_mult": dur_mult,
            "gen_dur": gen_dur,
            "ref_dur": ref_dur,
            "seed": seed,
            "had_voice_ref": ref_path is not None,
        },
    }
|