Manmay Nakhashi committed on
Commit
b8b67ad
·
1 Parent(s): 8cd4942

Refactor for ZeroGPU: lazy TTSServer load + @spaces.GPU decorator

Browse files

- Add 'spaces' to requirements
- Move TTSServer instantiation into _ensure_tts() so the GPU isn't touched at import time
- Wrap on_generate with @spaces.GPU(duration=120)
- Disable torch.compile (incompatible with ZeroGPU's brief GPU windows)
- Drop hardware:l40s from README frontmatter (Space hardware is now set to zero via API)

Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +24 -12
  3. requirements.txt +1 -0
README.md CHANGED
@@ -10,7 +10,7 @@ pinned: true
10
  license: other
11
  license_name: ltx-2-community
12
  license_link: https://huggingface.co/ResembleAI/Dramabox/blob/main/LICENSE
13
- hardware: l40s
14
  short_description: Expressive TTS with voice cloning — DramaBox demo
15
  ---
16
 
 
10
  license: other
11
  license_name: ltx-2-community
12
  license_link: https://huggingface.co/ResembleAI/Dramabox/blob/main/LICENSE
13
+ hf_oauth: false
14
  short_description: Expressive TTS with voice cloning — DramaBox demo
15
  ---
16
 
app.py CHANGED
@@ -12,6 +12,7 @@ import tempfile
12
  import time
13
 
14
  import gradio as gr
 
15
 
16
  # Local src import.
17
  sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
@@ -21,18 +22,27 @@ from model_downloader import get_all_paths # noqa: E402
21
 
22
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
23
  logging.info("Fetching DramaBox checkpoints from HuggingFace (cached after first run)...")
24
- paths = get_all_paths()
25
- logging.info("Loading DramaBox warm server (Gemma + DiT + VAE + Decoder)...")
26
- tts = TTSServer(
27
- checkpoint=paths["transformer"],
28
- full_checkpoint=paths["audio_components"],
29
- gemma_root=paths["gemma_root"],
30
- device="cuda",
31
- dtype=os.environ.get("LTX_DTYPE", "bf16"),
32
- compile_model=os.environ.get("LTX_COMPILE", "0") == "1",
33
- bnb_4bit=True, # default Gemma is unsloth pre-quantized
34
- )
35
- logging.info("Server ready.")
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  # ── Example prompts (shown as click-to-fill chips in the UI) ─────────────────
@@ -88,10 +98,12 @@ EXAMPLES: list[tuple[str, str]] = [
88
  ]
89
 
90
 
 
91
  def on_generate(prompt: str, audio_ref, cfg: float, stg: float, dur_mult: float, seed: int):
92
  if not prompt or not prompt.strip():
93
  raise gr.Error("Prompt is empty.")
94
  t0 = time.time()
 
95
  ref_path = audio_ref if audio_ref and os.path.exists(str(audio_ref)) else None
96
  output = tempfile.mktemp(suffix=".wav", prefix="dramabox_")
97
  tts.generate_to_file(
 
12
  import time
13
 
14
  import gradio as gr
15
+ import spaces
16
 
17
  # Local src import.
18
  sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
 
22
 
23
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
24
  logging.info("Fetching DramaBox checkpoints from HuggingFace (cached after first run)...")
25
+ PATHS = get_all_paths() # CPU-side download is fine outside the GPU window
26
+
27
+ # Lazy-loaded inside the @spaces.GPU function (no GPU available at import time on ZeroGPU).
28
+ _TTS: TTSServer | None = None
29
+
30
+
31
+ def _ensure_tts() -> TTSServer:
32
+ global _TTS
33
+ if _TTS is None:
34
+ logging.info("Loading DramaBox warm server (Gemma + DiT + VAE + Decoder)...")
35
+ _TTS = TTSServer(
36
+ checkpoint=PATHS["transformer"],
37
+ full_checkpoint=PATHS["audio_components"],
38
+ gemma_root=PATHS["gemma_root"],
39
+ device="cuda",
40
+ dtype=os.environ.get("LTX_DTYPE", "bf16"),
41
+ compile_model=False, # torch.compile breaks under ZeroGPU's brief GPU windows
42
+ bnb_4bit=True, # unsloth Gemma is pre-quantized
43
+ )
44
+ logging.info("TTSServer ready.")
45
+ return _TTS
46
 
47
 
48
  # ── Example prompts (shown as click-to-fill chips in the UI) ─────────────────
 
98
  ]
99
 
100
 
101
+ @spaces.GPU(duration=120)
102
  def on_generate(prompt: str, audio_ref, cfg: float, stg: float, dur_mult: float, seed: int):
103
  if not prompt or not prompt.strip():
104
  raise gr.Error("Prompt is empty.")
105
  t0 = time.time()
106
+ tts = _ensure_tts()
107
  ref_path = audio_ref if audio_ref and os.path.exists(str(audio_ref)) else None
108
  output = tempfile.mktemp(suffix=".wav", prefix="dramabox_")
109
  tts.generate_to_file(
requirements.txt CHANGED
@@ -12,5 +12,6 @@ transformers>=4.45.0
12
  huggingface_hub>=0.20.0,<1.0
13
  bitsandbytes>=0.43.0
14
  gradio>=4.0.0
 
15
  soundfile>=0.12.0
16
  resemble-perth @ git+https://github.com/resemble-ai/Perth.git@master
 
12
  huggingface_hub>=0.20.0,<1.0
13
  bitsandbytes>=0.43.0
14
  gradio>=4.0.0
15
+ spaces>=0.30.0
16
  soundfile>=0.12.0
17
  resemble-perth @ git+https://github.com/resemble-ai/Perth.git@master