File size: 6,036 Bytes
12ab2ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30a5c8e
 
12ab2ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bfce29
 
 
12ab2ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30a5c8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12ab2ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30a5c8e
 
 
 
 
 
 
 
 
 
 
12ab2ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""
Dramabox — Resemble AI directable speech engine.

Single-Space tool: generates a 48 kHz WAV "performance" from a scene prompt
(quoted dialogue + stage directions) and an optional voice reference. Mirrors
the official ResembleAI/Dramabox Space's on_generate(): same parameter order,
same defaults, same model invocation.

This module only runs on the videovoice-dramabox Space, which must vendor the
Dramabox `src/` directory (inference_server.py + model_downloader.py) and the
requirements-dramabox.txt deps. On any other Space the lazy import below
raises a clean RuntimeError rather than crashing app startup.

The module loads the TTSServer once on first request (warm-load pattern from
the upstream Space) and reuses it across calls.
"""
from __future__ import annotations

import logging
import os
import threading
import time
import uuid
from pathlib import Path

import spaces

# An explicit logger name (rather than __name__) places these records under
# the "tools_api" logger hierarchy.
logger = logging.getLogger("tools_api.dramabox")

# Backend env knobs — kept compatible with the upstream Space. LTX_DTYPE is
# handed to TTSServer as its dtype; defaults to "bf16" when unset.
_LTX_DTYPE = os.getenv("LTX_DTYPE", "bf16")

# Warm-load state. The TTSServer is built lazily by the first generation
# request and reused afterwards; the lock makes a burst of concurrent first
# requests perform a single load. (Warm calls were noted at ~2.5s on GPU.)
_tts_server = None
_tts_lock = threading.Lock()


def _ensure_server():
    """Return the process-wide Dramabox TTSServer, building it on first use.

    The first caller makes the vendored Dramabox sources importable, resolves
    checkpoint paths via ``get_all_paths()`` and constructs the server; every
    later caller receives the cached instance. The lock plus the re-check
    inside it means concurrent first callers build the server only once.

    Raises:
        RuntimeError: the vendored ``inference_server`` / ``model_downloader``
            modules cannot be imported on this Space. The original
            ImportError is chained as ``__cause__``.
    """
    global _tts_server

    # Fast path: once the server exists no locking is needed.
    cached = _tts_server
    if cached is not None:
        return cached

    with _tts_lock:
        # Another thread may have completed the load while we waited.
        if _tts_server is not None:
            return _tts_server

        import sys

        # Vendored from ResembleAI/Dramabox. Mirrors the upstream layout:
        # src/ holds inference_server.py, which itself puts the sibling ltx2/
        # on sys.path. Adding src/ here means app.py need not do it.
        src_dir = Path(__file__).parent.parent / "dramabox_src" / "src"
        src_entry = str(src_dir)
        if src_dir.exists() and src_entry not in sys.path:
            sys.path.insert(0, src_entry)

        try:
            from inference_server import TTSServer  # type: ignore[import-not-found]
            from model_downloader import get_all_paths  # type: ignore[import-not-found]
        except ImportError as exc:
            raise RuntimeError(
                "Dramabox is not installed on this Space. Vendor "
                "ResembleAI/Dramabox's src/ directory at "
                "VideoVoice-be/dramabox_src/ and install requirements-dramabox.txt."
            ) from exc

        logger.info("Fetching Dramabox checkpoints (cached after first run)...")
        ckpt = get_all_paths()

        logger.info("Loading Dramabox warm server (Gemma + DiT + VAE + Decoder)...")
        server = TTSServer(
            checkpoint=ckpt["transformer"],
            full_checkpoint=ckpt["audio_components"],
            gemma_root=ckpt["gemma_root"],
            device="cuda",
            dtype=_LTX_DTYPE,
            compile_model=False,   # torch.compile breaks under ZeroGPU's brief GPU windows
            bnb_4bit=True,         # unsloth Gemma is pre-quantized
        )
        _tts_server = server
        logger.info("Dramabox TTSServer ready.")
        return server


@spaces.GPU(duration=60)
def _generate_scene_gpu(
    *,
    prompt: str,
    out_dir: Path,
    audio_ref: Path | None,
    cfg: float,
    stg: float,
    dur_mult: float,
    gen_dur: float,
    ref_dur: float,
    seed: int,
) -> dict:
    """ZeroGPU entry point for a single Dramabox generation.

    Kept at module top level and decorated with ``spaces.GPU`` so HF detects
    Dramabox GPU usage at startup. ``duration=60`` is the GPU time budget
    (seconds) requested per call. Every argument is forwarded unchanged to
    ``_generate_impl``.

    NOTE(review): the first call also performs the cold model load
    (``_ensure_server``) inside this budget — confirm 60s is enough.
    """
    knobs = dict(
        cfg=cfg,
        stg=stg,
        dur_mult=dur_mult,
        gen_dur=gen_dur,
        ref_dur=ref_dur,
        seed=seed,
    )
    return _generate_impl(
        prompt=prompt,
        out_dir=out_dir,
        audio_ref=audio_ref,
        **knobs,
    )


def generate_scene(
    *,
    prompt: str,
    out_dir: Path,
    audio_ref: Path | None = None,
    cfg: float = 2.5,
    stg: float = 1.5,
    dur_mult: float = 1.1,
    gen_dur: float = 0.0,
    ref_dur: float = 10.0,
    seed: int = 42,
) -> dict:
    """Run Dramabox on ``prompt`` and write the resulting WAV under ``out_dir``.

    Input is validated here, before the ``spaces.GPU``-decorated
    ``_generate_scene_gpu`` is called, so a bad request never reaches the GPU
    wrapper.

    Args:
        prompt: Scene text (quoted dialogue + stage directions). Surrounding
            whitespace is stripped; ``None`` is treated as empty.
        out_dir: Directory that receives the WAV. A ``str`` or ``Path`` is
            accepted; the directory is created if missing.
        audio_ref: Optional voice-reference audio file.
        cfg: Forwarded to the model as ``cfg_scale``.
        stg: Forwarded as ``stg_scale``.
        dur_mult: Forwarded as ``duration_multiplier``.
        gen_dur: Forwarded as ``gen_duration``.
        ref_dur: Forwarded as ``ref_duration``.
        seed: Forwarded as ``seed``.

    Returns:
      {
        "filename": <basename of the WAV written under out_dir, "dramabox_*.wav">,
        "elapsed": <seconds spent generating>,
        "settings": {cfg, stg, dur_mult, gen_dur, ref_dur, seed, had_voice_ref},
      }

    Raises:
        ValueError: ``prompt`` is empty or whitespace-only.
    """
    cleaned = (prompt or "").strip()
    if not cleaned:
        raise ValueError("Prompt is empty.")

    return _generate_scene_gpu(
        prompt=cleaned,
        # Normalise here so a str path fails/works on this side of the GPU
        # wrapper instead of breaking later on `.mkdir`.
        out_dir=Path(out_dir),
        audio_ref=audio_ref,
        cfg=cfg,
        stg=stg,
        dur_mult=dur_mult,
        gen_dur=gen_dur,
        ref_dur=ref_dur,
        seed=seed,
    )


def _generate_impl(
    *,
    prompt: str,
    out_dir: Path,
    audio_ref: Path | None,
    cfg: float,
    stg: float,
    dur_mult: float,
    gen_dur: float,
    ref_dur: float,
    seed: int,
) -> dict:
    """Run one Dramabox generation and write the WAV into ``out_dir``.

    Args:
        prompt: Already-validated scene text.
        out_dir: Output directory (``str`` or ``Path``); created if missing.
        audio_ref: Optional voice-reference file. Anything that is not an
            existing regular file is ignored (with a warning when non-blank),
            and generation proceeds without a reference.
        cfg, stg, dur_mult, gen_dur, ref_dur, seed: Model knobs, coerced to
            float/int and passed to ``generate_to_file``.

    Returns:
        dict with ``filename`` (basename of the written WAV), ``elapsed``
        (seconds spent in ``generate_to_file``, excluding server load) and
        ``settings`` (echo of the knobs plus ``had_voice_ref``).
    """
    tts = _ensure_server()

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # The millisecond timestamp keeps names roughly time-sortable; the random
    # suffix stops two requests in the same millisecond from overwriting
    # each other's output.
    output = out_dir / f"dramabox_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}.wav"

    ref_path: str | None = None
    if audio_ref is not None and str(audio_ref).strip():
        candidate = Path(audio_ref)
        # is_file(), not exists(): Path("") is "." (an existing directory),
        # and directories must not be handed to the model as a voice reference.
        if candidate.is_file():
            ref_path = str(audio_ref)
        else:
            logger.warning(
                "Dramabox voice reference %s is not a file; generating without it.",
                candidate,
            )

    # perf_counter is monotonic, so the duration is immune to clock changes.
    t0 = time.perf_counter()
    tts.generate_to_file(
        prompt=prompt,
        output=str(output),
        voice_ref=ref_path,
        cfg_scale=float(cfg),
        stg_scale=float(stg),
        duration_multiplier=float(dur_mult),
        seed=int(seed),
        gen_duration=float(gen_dur),
        ref_duration=float(ref_dur),
    )
    elapsed = time.perf_counter() - t0
    logger.info("Dramabox generated in %.2fs -> %s", elapsed, output)

    return {
        "filename": output.name,
        "elapsed": elapsed,
        "settings": {
            "cfg": cfg,
            "stg": stg,
            "dur_mult": dur_mult,
            "gen_dur": gen_dur,
            "ref_dur": ref_dur,
            "seed": seed,
            "had_voice_ref": ref_path is not None,
        },
    }