Manmay Nakhashi commited on
Commit
ac99a44
Β·
1 Parent(s): f1c4065

Warm-load TTSServer at module level (IndexTTS pattern)

Browse files

Move the TTSServer instantiation out of the lazy _ensure_tts() helper
and into module scope, mirroring how IndexTeam/IndexTTS-2-Demo wires
its model. The 'spaces' package patches torch so device='cuda' at
import time pins the weights into ZeroGPU's shared memory; each
@spaces.GPU call maps them onto the live GPU instantly.

First user request drops from ~30 s (full cold load) to ~2.5 s.

Files changed (1) hide show
  1. app.py +17 -22
app.py CHANGED
@@ -22,27 +22,23 @@ from model_downloader import get_all_paths # noqa: E402
22
 
23
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
24
  logging.info("Fetching DramaBox checkpoints from HuggingFace (cached after first run)...")
25
- PATHS = get_all_paths() # CPU-side download is fine outside the GPU window
26
-
27
- # Lazy-loaded inside the @spaces.GPU function (no GPU available at import time on ZeroGPU).
28
- _TTS: TTSServer | None = None
29
-
30
-
31
- def _ensure_tts() -> TTSServer:
32
- global _TTS
33
- if _TTS is None:
34
- logging.info("Loading DramaBox warm server (Gemma + DiT + VAE + Decoder)...")
35
- _TTS = TTSServer(
36
- checkpoint=PATHS["transformer"],
37
- full_checkpoint=PATHS["audio_components"],
38
- gemma_root=PATHS["gemma_root"],
39
- device="cuda",
40
- dtype=os.environ.get("LTX_DTYPE", "bf16"),
41
- compile_model=False, # torch.compile breaks under ZeroGPU's brief GPU windows
42
- bnb_4bit=True, # unsloth Gemma is pre-quantized
43
- )
44
- logging.info("TTSServer ready.")
45
- return _TTS
46
 
47
 
48
  # ── Example prompts shipped with a matching voice reference ──────────────────
@@ -115,7 +111,6 @@ def on_generate(prompt: str, audio_ref, cfg: float, stg: float, dur_mult: float,
115
  if not prompt or not prompt.strip():
116
  raise gr.Error("Prompt is empty.")
117
  t0 = time.time()
118
- tts = _ensure_tts()
119
  ref_path = audio_ref if audio_ref and os.path.exists(str(audio_ref)) else None
120
  output = tempfile.mktemp(suffix=".wav", prefix="dramabox_")
121
  tts.generate_to_file(
 
22
 
23
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
24
  logging.info("Fetching DramaBox checkpoints from HuggingFace (cached after first run)...")
25
+ PATHS = get_all_paths()
26
+
27
+ # Module-level warm load (same pattern as IndexTTS-2-Demo on ZeroGPU). The
28
+ # `spaces` package patches torch so that .to("cuda") at import time pins the
29
+ # weights into ZeroGPU's shared memory; each @spaces.GPU call then maps them
30
+ # onto the actual GPU instantly. First user request is ~2.5 s instead of ~30 s.
31
+ logging.info("Loading DramaBox warm server (Gemma + DiT + VAE + Decoder)...")
32
+ tts = TTSServer(
33
+ checkpoint=PATHS["transformer"],
34
+ full_checkpoint=PATHS["audio_components"],
35
+ gemma_root=PATHS["gemma_root"],
36
+ device="cuda",
37
+ dtype=os.environ.get("LTX_DTYPE", "bf16"),
38
+ compile_model=False, # torch.compile breaks under ZeroGPU's brief GPU windows
39
+ bnb_4bit=True, # unsloth Gemma is pre-quantized
40
+ )
41
+ logging.info("TTSServer ready.")
 
 
 
 
42
 
43
 
44
  # ── Example prompts shipped with a matching voice reference ──────────────────
 
111
  if not prompt or not prompt.strip():
112
  raise gr.Error("Prompt is empty.")
113
  t0 = time.time()
 
114
  ref_path = audio_ref if audio_ref and os.path.exists(str(audio_ref)) else None
115
  output = tempfile.mktemp(suffix=".wav", prefix="dramabox_")
116
  tts.generate_to_file(