github-actions[bot]
deploy: switch to dramabox requirements @ f090cde
d79393d
"""
Dramabox — Resemble AI directable speech engine.
Single-Space tool: generates a 48 kHz WAV "performance" from a scene prompt
(quoted dialogue + stage directions) and an optional voice reference. Mirrors
the official ResembleAI/Dramabox Space's on_generate(): same parameter order,
same defaults, same model invocation.
This module only runs on the videovoice-dramabox Space, which must vendor the
Dramabox `src/` directory (inference_server.py + model_downloader.py) and the
requirements-dramabox.txt deps. On any other Space the lazy import below
raises a clean RuntimeError rather than crashing app startup.
The module loads the TTSServer once on first request (warm-load pattern from
the upstream Space) and reuses it across calls.
"""
from __future__ import annotations

import logging
import os
import threading
import time
import uuid
from pathlib import Path

import spaces
logger = logging.getLogger("tools_api.dramabox")
# Backend env knobs, read once at import time and kept compatible with the
# upstream Space. LTX_DTYPE is forwarded to TTSServer as `dtype` ("bf16" if unset).
_LTX_DTYPE = os.environ.get("LTX_DTYPE", "bf16")
# Warm-load cache: _ensure_server() builds the TTSServer on the first request
# and stores it here; later calls reuse it (~2.5s per call on a warm GPU).
_tts_server = None  # populated lazily on first generate() call
# Serializes the one-time load so a burst of concurrent first requests
# constructs the server only once.
_tts_lock = threading.Lock()
def _import_dramabox():
    """Expose the vendored Dramabox sources and import their two entry points.

    Returns:
        A ``(TTSServer, get_all_paths)`` pair taken from the vendored
        ``inference_server`` and ``model_downloader`` modules.

    Raises:
        RuntimeError: if either module cannot be imported (the ImportError is
            chained as ``__cause__``).
    """
    import sys

    # Upstream layout: dramabox_src/src holds inference_server.py, which then
    # puts its sibling ltx2/ on sys.path by itself. Adding the directory here
    # means app.py does not have to do the insert.
    src_dir = Path(__file__).parent.parent / "dramabox_src" / "src"
    src_entry = str(src_dir)
    if src_dir.exists() and src_entry not in sys.path:
        sys.path.insert(0, src_entry)

    try:
        from inference_server import TTSServer  # type: ignore[import-not-found]
        from model_downloader import get_all_paths  # type: ignore[import-not-found]
    except ImportError as exc:
        raise RuntimeError(
            "Dramabox is not installed on this Space. Vendor "
            "ResembleAI/Dramabox's src/ directory at "
            "VideoVoice-be/dramabox_src/ and install requirements-dramabox.txt."
        ) from exc
    return TTSServer, get_all_paths


def _ensure_server():
    """Return the shared Dramabox TTSServer, constructing it on first use.

    The instance lives in the module-level ``_tts_server``. Once it is set,
    the unlocked check returns it directly. Otherwise ``_tts_lock`` is taken
    and the check repeated, so concurrent first callers build it only once.

    Raises:
        RuntimeError: when the Dramabox ``src/`` modules are not vendored on
            this Space (see ``_import_dramabox``).
    """
    global _tts_server
    if _tts_server is None:
        with _tts_lock:
            if _tts_server is None:
                server_cls, fetch_paths = _import_dramabox()

                logger.info("Fetching Dramabox checkpoints (cached after first run)...")
                ckpt = fetch_paths()

                logger.info("Loading Dramabox warm server (Gemma + DiT + VAE + Decoder)...")
                _tts_server = server_cls(
                    checkpoint=ckpt["transformer"],
                    full_checkpoint=ckpt["audio_components"],
                    gemma_root=ckpt["gemma_root"],
                    device="cuda",
                    dtype=_LTX_DTYPE,
                    # torch.compile breaks under ZeroGPU's brief GPU windows.
                    compile_model=False,
                    # The unsloth Gemma checkpoint ships pre-quantized.
                    bnb_4bit=True,
                )
                logger.info("Dramabox TTSServer ready.")
    return _tts_server
@spaces.GPU(duration=60)
def _generate_scene_gpu(
    *,
    prompt: str,
    out_dir: Path,
    audio_ref: Path | None,
    cfg: float,
    stg: float,
    dur_mult: float,
    gen_dur: float,
    ref_dur: float,
    seed: int,
) -> dict:
    """ZeroGPU entry point for one Dramabox generation.

    This is a top-level function so HF detects Dramabox GPU usage at startup.
    It adds no logic: every keyword is forwarded unchanged to
    ``_generate_impl``, which runs inside the ``spaces.GPU(duration=60)``
    scope and returns its result dict.
    """
    forwarded = dict(
        prompt=prompt,
        out_dir=out_dir,
        audio_ref=audio_ref,
        cfg=cfg,
        stg=stg,
        dur_mult=dur_mult,
        gen_dur=gen_dur,
        ref_dur=ref_dur,
        seed=seed,
    )
    return _generate_impl(**forwarded)
def generate_scene(
*,
prompt: str,
out_dir: Path,
audio_ref: Path | None = None,
cfg: float = 2.5,
stg: float = 1.5,
dur_mult: float = 1.1,
gen_dur: float = 0.0,
ref_dur: float = 10.0,
seed: int = 42,
) -> dict:
"""
Run Dramabox on `prompt` and write the resulting WAV under `out_dir`.
Returns:
{
"filename": "dramabox_<run_id_short>.wav",
"elapsed": <seconds>,
"settings": {...echo of inputs used...},
}
"""
prompt = (prompt or "").strip()
if not prompt:
raise ValueError("Prompt is empty.")
return _generate_scene_gpu(
prompt=prompt,
out_dir=out_dir,
audio_ref=audio_ref,
cfg=cfg,
stg=stg,
dur_mult=dur_mult,
gen_dur=gen_dur,
ref_dur=ref_dur,
seed=seed,
)
def _generate_impl(
    *,
    prompt: str,
    out_dir: Path,
    audio_ref: Path | None,
    cfg: float,
    stg: float,
    dur_mult: float,
    gen_dur: float,
    ref_dur: float,
    seed: int,
) -> dict:
    """Run one Dramabox generation and return metadata about the written WAV.

    Loads (or reuses) the shared TTSServer via ``_ensure_server`` and creates
    ``out_dir`` if needed. It then calls ``TTSServer.generate_to_file`` with
    the numeric parameters coerced to float/int.

    Args:
        prompt: scene text, already validated as non-empty by the caller.
        out_dir: output directory (Path or str).
        audio_ref: optional voice-reference path. It is used only when it
            points at an existing file; otherwise a warning is logged and
            generation proceeds without a reference.
        cfg, stg, dur_mult, gen_dur, ref_dur, seed: forwarded as cfg_scale,
            stg_scale, duration_multiplier, gen_duration, ref_duration, seed.

    Returns:
        {"filename": <WAV basename>, "elapsed": <seconds in the model call>,
         "settings": {<echo of the inputs>, "had_voice_ref": <bool>}}

    Raises:
        RuntimeError: from ``_ensure_server`` when Dramabox is not vendored.
    """
    tts = _ensure_server()
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # A millisecond timestamp alone can collide when two requests finish in
    # the same millisecond, silently overwriting one WAV with the other. The
    # random suffix keeps names unique and the time-sortable prefix intact.
    stamp_ms = int(time.time() * 1000)
    output = out_dir / f"dramabox_{stamp_ms}_{uuid.uuid4().hex[:8]}.wav"

    ref_path: str | None = None
    if audio_ref:
        # is_file(), not exists(): Path("") resolves to "." and directories
        # also "exist". Neither is a usable voice reference.
        if Path(audio_ref).is_file():
            ref_path = str(audio_ref)
        else:
            logger.warning(
                "Dramabox voice reference %s is not a file; generating without it.",
                audio_ref,
            )

    # perf_counter is monotonic, so wall-clock adjustments can't skew timing.
    t0 = time.perf_counter()
    tts.generate_to_file(
        prompt=prompt,
        output=str(output),
        voice_ref=ref_path,
        cfg_scale=float(cfg),
        stg_scale=float(stg),
        duration_multiplier=float(dur_mult),
        seed=int(seed),
        gen_duration=float(gen_dur),
        ref_duration=float(ref_dur),
    )
    elapsed = time.perf_counter() - t0
    logger.info("Dramabox generated in %.2fs -> %s", elapsed, output)
    return {
        "filename": output.name,
        "elapsed": elapsed,
        "settings": {
            "cfg": cfg,
            "stg": stg,
            "dur_mult": dur_mult,
            "gen_dur": gen_dur,
            "ref_dur": ref_dur,
            "seed": seed,
            "had_voice_ref": ref_path is not None,
        },
    }