File size: 6,233 Bytes
0422215
4230483
 
 
 
 
 
 
 
 
 
 
 
 
 
0422215
 
 
4230483
0422215
4230483
 
 
 
 
0422215
 
4230483
 
 
 
 
d79393d
 
 
 
 
4230483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0422215
d105ee2
4230483
d105ee2
4230483
 
d105ee2
4230483
 
d105ee2
4230483
 
 
d105ee2
 
 
4230483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23668b5
 
 
4230483
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""
Dramabox Space entrypoint — pure Gradio 5.x for ZeroGPU compat.

Why no FastAPI mount:
- ZeroGPU only allocates a GPU for @spaces.GPU functions wired into Gradio
  events (button.click / Interface inputs). FastAPI-mounted endpoints
  don't trigger HF's ZeroGPU scheduler, and the mounting pattern was
  also causing HF's runtime to kill the container after startup.
- This file mirrors the upstream ResembleAI/Dramabox Space's app.py.
- The React frontend (DramaboxTool.tsx) calls the named API endpoint
  via `@gradio/client` instead of fetch().

Dramabox checkpoints are lazy-loaded on the first request so the Space
boots even before `dramabox_src/` is vendored — the first call will surface
the import error to the caller; subsequent calls reuse the warm server.
"""
from __future__ import annotations

import logging
import os
import sys
import tempfile
import threading
import time
from pathlib import Path

import gradio as gr
import spaces

# Configure root logging at import time: INFO and above, with timestamp and level.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)

# The Dramabox sources are vendored rather than installed, because Resemble
# doesn't publish TTSServer to PyPI. `dramabox_src/` follows the upstream
# Space layout, with `src/` (inference glue) sitting next to `ltx2/` (LTX-2
# core packages). `inference_server.py` adds `ltx2/` to sys.path on its own
# (`sys.path.insert(0, APP_DIR/'ltx2')`, APP_DIR = parent.parent), so only
# `dramabox_src/src/` has to be made importable from here.
_VENDORED_SRC = Path(__file__).parent.joinpath("dramabox_src", "src")
if _VENDORED_SRC.exists():
    if str(_VENDORED_SRC) not in sys.path:
        sys.path.insert(0, str(_VENDORED_SRC))

# Shared TTS server instance; stays None until the first on_generate() call.
_tts = None
# Held while `_tts` is being constructed in _get_tts().
_tts_lock = threading.Lock()


def _get_tts():
    """Return the shared TTSServer, constructing it on the first call.

    The instance is cached in the module-level `_tts`; later calls return
    it without taking the lock. Construction happens while holding
    `_tts_lock`, and `_tts` is re-checked after the lock is acquired.

    Raises:
        gr.Error: if `inference_server` or `model_downloader` cannot be
            imported (i.e. `dramabox_src/` isn't vendored). The original
            ImportError is chained as the cause.
    """
    global _tts
    if _tts is None:
        with _tts_lock:
            # Check again now that we hold the lock: the instance may have
            # been created while this call was waiting.
            if _tts is None:
                try:
                    from inference_server import TTSServer  # type: ignore[import-not-found]
                    from model_downloader import get_all_paths  # type: ignore[import-not-found]
                except ImportError as exc:
                    message = (
                        "Dramabox source not vendored on this Space. Copy "
                        "ResembleAI/Dramabox's src/ into the repo as dramabox_src/."
                    )
                    raise gr.Error(message) from exc

                logging.info("Fetching Dramabox checkpoints (cached after first run)...")
                checkpoint_paths = get_all_paths()

                logging.info("Loading Dramabox warm server (Gemma + DiT + VAE + Decoder)...")
                server_options = {
                    "checkpoint": checkpoint_paths["transformer"],
                    "full_checkpoint": checkpoint_paths["audio_components"],
                    "gemma_root": checkpoint_paths["gemma_root"],
                    "device": "cuda",
                    # Precision can be overridden through the LTX_DTYPE env var.
                    "dtype": os.environ.get("LTX_DTYPE", "bf16"),
                    # torch.compile breaks under ZeroGPU's brief GPU windows.
                    "compile_model": False,
                    # unsloth Gemma is pre-quantized.
                    "bnb_4bit": True,
                }
                _tts = TTSServer(**server_options)
                logging.info("Dramabox TTSServer ready.")
    return _tts


@spaces.GPU(duration=60)
def on_generate(prompt, audio_ref, cfg, stg, dur_mult, gen_dur, ref_dur, seed):
    """Generate speech for a scene prompt and return the output WAV path.

    Main generation endpoint — wired to the Generate button below so
    HF's ZeroGPU scheduler detects it at import time.

    Args:
        prompt: Scene prompt text; must contain non-whitespace characters.
        audio_ref: Optional path to a voice-reference audio file. Ignored
            (treated as no reference) if empty or if the path does not exist.
        cfg: Forwarded to the server as `cfg_scale` (converted to float).
        stg: Forwarded as `stg_scale` (converted to float).
        dur_mult: Forwarded as `duration_multiplier` (converted to float).
        gen_dur: Forwarded as `gen_duration` (converted to float).
        ref_dur: Forwarded as `ref_duration` (converted to float).
        seed: Forwarded as `seed` (converted to int).

    Returns:
        Filesystem path of the generated `.wav` file.

    Raises:
        gr.Error: if the prompt is empty, or if the TTS server cannot be
            loaded (see `_get_tts`).
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt is empty.")
    tts = _get_tts()
    t0 = time.time()
    ref_path = audio_ref if audio_ref and os.path.exists(str(audio_ref)) else None

    # tempfile.mktemp() is deprecated: it only returns a name, so another
    # process can claim the path before we write to it. mkstemp() creates the
    # file atomically; we close the descriptor and let the server write to
    # the reserved path.
    fd, output = tempfile.mkstemp(suffix=".wav", prefix="dramabox_")
    os.close(fd)
    try:
        tts.generate_to_file(
            prompt=prompt,
            output=output,
            voice_ref=ref_path,
            cfg_scale=float(cfg),
            stg_scale=float(stg),
            duration_multiplier=float(dur_mult),
            seed=int(seed),
            gen_duration=float(gen_dur),
            ref_duration=float(ref_dur),
        )
    except Exception:
        # Don't leave an empty placeholder file behind when generation fails.
        try:
            os.unlink(output)
        except OSError:
            pass
        raise
    elapsed = time.time() - t0
    logging.info("Dramabox generated in %.2fs -> %s", elapsed, output)
    return output


# Gradio UI definition. The layout is a two-column row: prompt, voice
# reference and the Generate button on the left; inference settings and the
# audio output on the right.
with gr.Blocks(title="VideoVoice Dramabox") as demo:
    gr.Markdown(
        """
        # VideoVoice — Dramabox

        Resemble AI's directable speech engine ("scene prompts" with quoted
        dialogue and stage directions). The React frontend at
        [videovoice.app/app/dramabox](https://videovoice.app/app/dramabox)
        is the primary UI; this Space exposes the model via the named
        `/dramabox` Gradio API endpoint, called from the React app through
        `@gradio/client`.
        """
    )

    with gr.Row():
        with gr.Column(scale=3):
            prompt_in = gr.Textbox(
                label="Scene prompt",
                placeholder='A weary detective, "I told you it was him." He sighs. "Every time."',
                lines=6,
            )
            # type="filepath" hands on_generate a path string rather than raw audio data.
            audio_ref_in = gr.Audio(
                label="Voice reference (optional, 10+ seconds)",
                type="filepath",
            )
            gen_btn = gr.Button("Generate", variant="primary", size="lg")
        with gr.Column(scale=2):
            with gr.Accordion("Inference settings", open=True):
                cfg_in = gr.Slider(1.0, 10.0, value=2.5, step=0.5, label="CFG scale")
                stg_in = gr.Slider(0.0, 5.0, value=1.5, step=0.5, label="STG scale")
                dur_mult_in = gr.Slider(
                    0.8, 2.0, value=1.1, step=0.05,
                    label="Duration × (only used when target duration = 0)",
                )
                gen_dur_in = gr.Slider(
                    0.0, 60.0, value=0.0, step=1.0,
                    label="Target duration (s) — 0 = auto",
                )
                ref_dur_in = gr.Slider(
                    3.0, 30.0, value=10.0, step=1.0,
                    label="Reference duration (s)",
                )
                seed_in = gr.Number(value=42, label="Seed", precision=0)
            audio_out = gr.Audio(label="Generated audio", type="filepath")

    # The input order must match on_generate's positional parameters.
    # api_name="dramabox" gives the event the named endpoint that the
    # Markdown text above refers to as `/dramabox`.
    gen_btn.click(
        on_generate,
        inputs=[prompt_in, audio_ref_in, cfg_in, stg_in,
                dur_mult_in, gen_dur_in, ref_dur_in, seed_in],
        outputs=[audio_out],
        api_name="dramabox",
    )


if __name__ == "__main__":
    # Only start the app when this file is executed as a script: call
    # queue() on the Blocks app, then launch() on what it returns.
    queued_demo = demo.queue()
    queued_demo.launch()