Rafii committed on
Commit
d33ca97
·
1 Parent(s): 3474e83

deploy: switch to chatterbox requirements @ 98aec56

Browse files
Files changed (3) hide show
  1. steps/_tts_models.py +0 -34
  2. steps/s4_preview.py +61 -26
  3. steps/s4_tts.py +69 -27
steps/_tts_models.py DELETED
@@ -1,34 +0,0 @@
1
- """Process-cached loaders for TTS models (Chatterbox).
2
-
3
- ZeroGPU best practice: load weights to `cuda` outside `@spaces.GPU` scopes
4
- (via CUDA emulation) so the time-budgeted GPU calls only contain inference.
5
- On Mac/CPU these fall back to MPS or CPU.
6
-
7
- Callers should treat returned models as singletons — never call
8
- `del model` or `torch.cuda.empty_cache()` on them between pipeline steps.
9
- """
10
- from __future__ import annotations
11
-
12
- import torch
13
-
14
-
15
- _CHATTERBOX = None
16
-
17
-
18
- def _select_device() -> str:
19
- if torch.cuda.is_available():
20
- return "cuda"
21
- if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
22
- return "mps"
23
- return "cpu"
24
-
25
-
26
- def get_chatterbox():
27
- global _CHATTERBOX
28
- if _CHATTERBOX is None:
29
- from chatterbox.mtl_tts import ChatterboxMultilingualTTS
30
-
31
- device = _select_device()
32
- print(f"[tts] Loading Chatterbox Multilingual on {device}...")
33
- _CHATTERBOX = ChatterboxMultilingualTTS.from_pretrained(device)
34
- return _CHATTERBOX
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
steps/s4_preview.py CHANGED
@@ -20,12 +20,6 @@ import torchaudio
20
 
21
  TTS_ENGINE = os.getenv("TTS_ENGINE", "chatterbox").lower()
22
 
23
- # Conditional imports based on TTS_ENGINE
24
- if TTS_ENGINE == "chatterbox":
25
- from steps._tts_models import get_chatterbox
26
- else:
27
- get_chatterbox = None
28
-
29
  import spaces
30
 
31
 
@@ -90,21 +84,44 @@ def _clip_audio(path: str, max_sec: float = 10.0) -> str:
90
  return path
91
 
92
 
93
- @spaces.GPU(duration=30)
94
- def _gpu_preview_chatterbox_segment(
95
- model,
96
- text: str,
97
  language_id: str,
98
- ref_audio_path: str,
99
  ):
100
- return model.generate(
101
- text[:300],
102
- language_id=language_id,
103
- audio_prompt_path=ref_audio_path,
104
- exaggeration=0.5,
105
- temperature=0.8,
106
- cfg_weight=0.5,
107
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
 
110
  # ── Chatterbox Multilingual preview ──────────────────────────
@@ -116,12 +133,28 @@ def _preview_chatterbox(
116
  ):
117
  """Generate a stitched preview WAV using Chatterbox Multilingual."""
118
  try:
119
- yield " [preview] Preparing Chatterbox Multilingual...\n"
120
- model = get_chatterbox()
121
-
122
  # Clip reference audio to max 10 seconds to prevent weird noise/artifacts
123
  ref_audio_clipped = _clip_audio(reference_audio_path, max_sec=10.0)
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  part_paths = []
126
  total = len(segments)
127
  for i, seg in enumerate(segments):
@@ -129,11 +162,13 @@ def _preview_chatterbox(
129
  text = seg.get("tts_text", seg.get("translated_text", seg["text"]))
130
  out_path = os.path.join(output_dir, f"cb_prev_{i:04d}.wav")
131
 
132
- wav = _gpu_preview_chatterbox_segment(
133
- model=model,
134
- text=text,
135
  language_id=language_id,
136
- ref_audio_path=ref_audio_clipped,
 
 
 
137
  )
138
  torchaudio.save(out_path, wav, model.sr, encoding="PCM_S", bits_per_sample=16)
139
  part_paths.append(out_path)
 
20
 
21
  TTS_ENGINE = os.getenv("TTS_ENGINE", "chatterbox").lower()
22
 
 
 
 
 
 
 
23
  import spaces
24
 
25
 
 
84
  return path
85
 
86
 
87
+ @spaces.GPU(duration=60)
88
+ def _gpu_preview_chatterbox_batch(
89
+ segments: list[dict],
90
+ ref_audio_clipped: str,
91
  language_id: str,
92
+ output_dir: str,
93
  ):
94
+ """Load + run Chatterbox preview synthesis inside one GPU scope."""
95
+ from chatterbox.mtl_tts import ChatterboxMultilingualTTS
96
+
97
+ print(" [preview] Loading Chatterbox in GPU scope...")
98
+ model = ChatterboxMultilingualTTS.from_pretrained("cuda")
99
+ part_paths = []
100
+ total = len(segments)
101
+
102
+ for i, seg in enumerate(segments):
103
+ text = seg.get("tts_text", seg.get("translated_text", seg["text"]))
104
+ out_path = os.path.join(output_dir, f"cb_prev_{i:04d}.wav")
105
+
106
+ print(f" [preview] Chatterbox: Synthesising segment {i+1}/{total}...")
107
+ wav = model.generate(
108
+ text[:300],
109
+ language_id=language_id,
110
+ audio_prompt_path=ref_audio_clipped,
111
+ exaggeration=0.5,
112
+ temperature=0.8,
113
+ cfg_weight=0.5,
114
+ )
115
+ torchaudio.save(
116
+ out_path,
117
+ wav.detach().cpu(),
118
+ model.sr,
119
+ encoding="PCM_S",
120
+ bits_per_sample=16,
121
+ )
122
+ part_paths.append(out_path)
123
+
124
+ return part_paths
125
 
126
 
127
  # ── Chatterbox Multilingual preview ──────────────────────────
 
133
  ):
134
  """Generate a stitched preview WAV using Chatterbox Multilingual."""
135
  try:
 
 
 
136
  # Clip reference audio to max 10 seconds to prevent weird noise/artifacts
137
  ref_audio_clipped = _clip_audio(reference_audio_path, max_sec=10.0)
138
 
139
+ device = _get_device()
140
+ if device == "cuda":
141
+ yield " [preview] Preparing Chatterbox batch preview (device=cuda)...\n"
142
+ part_paths = _gpu_preview_chatterbox_batch(
143
+ segments=segments,
144
+ ref_audio_clipped=ref_audio_clipped,
145
+ language_id=language_id,
146
+ output_dir=output_dir,
147
+ )
148
+ stitched = os.path.join(output_dir, "preview_chatterbox.wav")
149
+ _stitch_wavs(part_paths, stitched)
150
+ yield " ✓ Chatterbox preview complete\n"
151
+ return stitched
152
+
153
+ yield f" [preview] Preparing Chatterbox Multilingual (device={device})...\n"
154
+ from chatterbox.mtl_tts import ChatterboxMultilingualTTS
155
+
156
+ model = ChatterboxMultilingualTTS.from_pretrained(device)
157
+
158
  part_paths = []
159
  total = len(segments)
160
  for i, seg in enumerate(segments):
 
162
  text = seg.get("tts_text", seg.get("translated_text", seg["text"]))
163
  out_path = os.path.join(output_dir, f"cb_prev_{i:04d}.wav")
164
 
165
+ wav = model.generate(
166
+ text[:300],
 
167
  language_id=language_id,
168
+ audio_prompt_path=ref_audio_clipped,
169
+ exaggeration=0.5,
170
+ temperature=0.8,
171
+ cfg_weight=0.5,
172
  )
173
  torchaudio.save(out_path, wav, model.sr, encoding="PCM_S", bits_per_sample=16)
174
  part_paths.append(out_path)
steps/s4_tts.py CHANGED
@@ -20,31 +20,56 @@ from tqdm import tqdm
20
 
21
  TTS_ENGINE = os.getenv("TTS_ENGINE", "chatterbox").lower()
22
 
23
- # Conditional imports based on TTS_ENGINE
24
- if TTS_ENGINE == "chatterbox":
25
- from steps._tts_models import get_chatterbox
26
- else:
27
- # OmniVoice mode - chatterbox imports not needed
28
- get_chatterbox = None
29
-
30
  import spaces
31
 
32
 
33
- @spaces.GPU(duration=60)
34
- def _gpu_chatterbox_generate(
35
- model,
36
- text: str,
37
  language_id: str,
38
- ref_audio_path: str,
39
  ):
40
- return model.generate(
41
- text[:300],
42
- language_id=language_id,
43
- audio_prompt_path=ref_audio_path,
44
- exaggeration=0.5,
45
- temperature=0.8,
46
- cfg_weight=0.5,
47
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
  # ── Chatterbox Multilingual ─────────────────────────────────
@@ -54,12 +79,27 @@ def _synthesise_chatterbox(
54
  language_id: str,
55
  output_dir: str,
56
  ):
57
- yield " [s4] Preparing Chatterbox Multilingual TTS...\n"
58
- model = get_chatterbox()
59
-
60
  # Clip reference audio to max 15 seconds to prevent weird noise/artifacts
61
  ref_audio_clipped = _clip_audio(reference_audio_path, max_sec=15.0)
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  results = []
64
  total = len(segments)
65
  for i, seg in enumerate(segments):
@@ -71,11 +111,13 @@ def _synthesise_chatterbox(
71
  max_tokens = min(1000, max(150, int(orig_dur * 75 * 1.5)))
72
  _ = max_tokens
73
 
74
- wav = _gpu_chatterbox_generate(
75
- model=model,
76
- text=text,
77
  language_id=language_id,
78
- ref_audio_path=ref_audio_clipped,
 
 
 
79
  )
80
 
81
  wav = _trim_trailing_noise(wav, model.sr)
 
20
 
21
  TTS_ENGINE = os.getenv("TTS_ENGINE", "chatterbox").lower()
22
 
 
 
 
 
 
 
 
23
  import spaces
24
 
25
 
26
+ @spaces.GPU(duration=120)
27
+ def _gpu_chatterbox_full_batch(
28
+ segments: list[dict],
29
+ ref_audio_clipped: str,
30
  language_id: str,
31
+ output_dir: str,
32
  ):
33
+ """
34
+ Load + run Chatterbox inside a single GPU-decorated scope.
35
+
36
+ ZeroGPU only intercepts CUDA init while the decorated function is active,
37
+ so constructing the CUDA model here avoids low-level torch CUDA init errors.
38
+ """
39
+ from chatterbox.mtl_tts import ChatterboxMultilingualTTS
40
+
41
+ print(" [s4] Loading Chatterbox in GPU scope...")
42
+ model = ChatterboxMultilingualTTS.from_pretrained("cuda")
43
+ results = []
44
+ total = len(segments)
45
+
46
+ for i, seg in enumerate(segments):
47
+ text = seg.get("tts_text", seg.get("translated_text", seg["text"]))
48
+ out_path = os.path.join(output_dir, f"seg_{i:04d}.wav")
49
+ orig_dur = seg["end"] - seg["start"]
50
+
51
+ print(f" [s4] Chatterbox: Synthesising segment {i+1}/{total}...")
52
+ wav = model.generate(
53
+ text[:300],
54
+ language_id=language_id,
55
+ audio_prompt_path=ref_audio_clipped,
56
+ exaggeration=0.5,
57
+ temperature=0.8,
58
+ cfg_weight=0.5,
59
+ )
60
+
61
+ wav = _trim_trailing_noise(wav, model.sr)
62
+ wav = _trim_to_duration(wav, model.sr, orig_dur)
63
+ torchaudio.save(
64
+ out_path,
65
+ wav.detach().cpu(),
66
+ model.sr,
67
+ encoding="PCM_S",
68
+ bits_per_sample=16,
69
+ )
70
+ results.append({**seg, "tts_path": out_path})
71
+
72
+ return results
73
 
74
 
75
  # ── Chatterbox Multilingual ─────────────────────────────────
 
79
  language_id: str,
80
  output_dir: str,
81
  ):
 
 
 
82
  # Clip reference audio to max 15 seconds to prevent weird noise/artifacts
83
  ref_audio_clipped = _clip_audio(reference_audio_path, max_sec=15.0)
84
 
85
+ device = _get_device()
86
+ if device == "cuda":
87
+ yield " [s4] Preparing Chatterbox batch processing (device=cuda)...\n"
88
+ results = _gpu_chatterbox_full_batch(
89
+ segments=segments,
90
+ ref_audio_clipped=ref_audio_clipped,
91
+ language_id=language_id,
92
+ output_dir=output_dir,
93
+ )
94
+ yield f" [s4] Chatterbox TTS complete — {len(results)} segments synthesised ✓\n"
95
+ yield {"__TTS_RESULT__": results}
96
+ return
97
+
98
+ yield f" [s4] Preparing Chatterbox Multilingual TTS (device={device})...\n"
99
+ from chatterbox.mtl_tts import ChatterboxMultilingualTTS
100
+
101
+ model = ChatterboxMultilingualTTS.from_pretrained(device)
102
+
103
  results = []
104
  total = len(segments)
105
  for i, seg in enumerate(segments):
 
111
  max_tokens = min(1000, max(150, int(orig_dur * 75 * 1.5)))
112
  _ = max_tokens
113
 
114
+ wav = model.generate(
115
+ text[:300],
 
116
  language_id=language_id,
117
+ audio_prompt_path=ref_audio_clipped,
118
+ exaggeration=0.5,
119
+ temperature=0.8,
120
+ cfg_weight=0.5,
121
  )
122
 
123
  wav = _trim_trailing_noise(wav, model.sr)