Spaces:
Running on Zero
Running on Zero
"""
Step 4 (optional): Generate short voice previews with a TTS model.

Takes the translated segments that start within the first few seconds
(``max_preview_seconds``, 5 s by default), loads the TTS engine selected by
``TTS_ENGINE``, synthesises those segments, stitches them into a single
preview WAV, and returns the file path.

Currently supports: Chatterbox Multilingual, OmniVoice.

Environment:
    TTS_ENGINE: "chatterbox" or "omnivoice" — controls which engine loads
        (case-insensitive, surrounding whitespace ignored; default "chatterbox").
"""
import gc
import os
import subprocess
from pathlib import Path

# NOTE(review): `spaces` (Hugging Face ZeroGPU helper) is not referenced by the
# code visible in this module; kept because importing it may be required by the
# Space runtime — confirm before removing.
import spaces
import torch
import torchaudio

TTS_ENGINE = os.getenv("TTS_ENGINE", "chatterbox").strip().lower()
| def _filter_preview_segments(segments: list[dict], max_seconds: float = 30.0) -> list[dict]: | |
| """Return segments whose start time is within the first `max_seconds`.""" | |
| return [s for s in segments if s["start"] < max_seconds] | |
| def _stitch_wavs(wav_paths: list[str], output_path: str) -> str: | |
| """Concatenate multiple WAV files into one using ffmpeg.""" | |
| if not wav_paths: | |
| raise ValueError("No WAVs to stitch") | |
| if len(wav_paths) == 1: | |
| import shutil | |
| shutil.copy(wav_paths[0], output_path) | |
| return output_path | |
| concat_list = output_path + ".concat.txt" | |
| with open(concat_list, "w") as f: | |
| for p in wav_paths: | |
| f.write(f"file '{os.path.abspath(p)}'\n") | |
| cmd = [ | |
| "ffmpeg", "-y", | |
| "-f", "concat", "-safe", "0", | |
| "-i", concat_list, | |
| "-c", "copy", | |
| output_path, | |
| ] | |
| result = subprocess.run(cmd, capture_output=True, text=True) | |
| os.remove(concat_list) | |
| if result.returncode != 0: | |
| raise RuntimeError(f"ffmpeg concat failed: {result.stderr[:300]}") | |
| return output_path | |
| def _ensure_browser_wav(path: str) -> str: | |
| """Re-encode a WAV to 16-bit PCM 44100 Hz so browsers can play it.""" | |
| safe_path = path.replace(".wav", "_safe.wav") | |
| cmd = [ | |
| "ffmpeg", "-y", "-i", path, | |
| "-ar", "44100", "-ac", "1", "-sample_fmt", "s16", | |
| "-c:a", "pcm_s16le", | |
| safe_path, | |
| ] | |
| result = subprocess.run(cmd, capture_output=True, text=True) | |
| if result.returncode == 0: | |
| os.replace(safe_path, path) | |
| return path | |
def _clip_audio(path: str, max_sec: float = 10.0) -> str:
    """Return a path to audio no longer than ``max_sec`` seconds.

    Long reference clips make voice cloning slow, so they are truncated to
    their first ``max_sec`` seconds. If the audio already fits, ``path`` itself
    is returned and nothing is written. Otherwise the truncated audio is saved
    as WAV next to the source, as ``<stem>_clipped.wav``. The source file is
    never modified.

    Args:
        path: audio file readable by ``torchaudio.load``.
        max_sec: maximum duration to keep, in seconds.

    Returns:
        ``path`` or the path of the newly written clipped WAV.
    """
    wav, sr = torchaudio.load(path)  # channels-first: (channels, frames)
    frames = int(max_sec * sr)
    if wav.shape[1] <= frames:
        return path
    # Build the output name from the stem. str.replace(".wav", ...) returned the
    # input path unchanged for non-.wav files (e.g. an .mp3 upload), so the save
    # below overwrote the user's original reference audio.
    root, _ext = os.path.splitext(path)
    out_path = f"{root}_clipped.wav"
    torchaudio.save(out_path, wav[:, :frames], sr)
    return out_path
def _gpu_preview_chatterbox_batch(
    segments: list[dict],
    ref_audio_clipped: str,
    language_id: str,
    output_dir: str,
):
    """Load Chatterbox Multilingual on CUDA and synthesise every preview segment.

    Model load and synthesis live in one function so that a single GPU session
    covers the whole batch. Progress is printed to stdout rather than yielded.

    NOTE(review): the module imports ``spaces`` but no ``@spaces.GPU`` decorator
    is applied here. Confirm whether GPU scoping on ZeroGPU is provided elsewhere.

    Args:
        segments: preview segments; text is taken from "tts_text", else
            "translated_text", else "text".
        ref_audio_clipped: reference voice WAV used as the audio prompt.
        language_id: language code passed to ``generate``.
        output_dir: directory for the ``cb_prev_NNNN.wav`` parts.

    Returns:
        List of written part paths, in segment order.
    """
    from chatterbox.mtl_tts import ChatterboxMultilingualTTS

    print(" [preview] Loading Chatterbox in GPU scope...")
    tts = ChatterboxMultilingualTTS.from_pretrained("cuda")
    n_segments = len(segments)
    written: list[str] = []
    for idx, segment in enumerate(segments):
        spoken = segment.get("tts_text", segment.get("translated_text", segment["text"]))
        target = os.path.join(output_dir, f"cb_prev_{idx:04d}.wav")
        print(f" [preview] Chatterbox: Synthesising segment {idx+1}/{n_segments}...")
        # Text is capped at 300 characters per segment.
        audio = tts.generate(
            spoken[:300],
            language_id=language_id,
            audio_prompt_path=ref_audio_clipped,
            exaggeration=0.5,
            temperature=0.8,
            cfg_weight=0.5,
        )
        # Move the output to CPU before writing 16-bit signed PCM.
        cpu_audio = audio.detach().cpu()
        torchaudio.save(target, cpu_audio, tts.sr, encoding="PCM_S", bits_per_sample=16)
        written.append(target)
    return written
# ── Chatterbox Multilingual preview ──────────────────────────
def _preview_chatterbox(
    segments: list[dict],
    reference_audio_path: str,
    language_id: str,
    output_dir: str,
):
    """Generate a stitched preview WAV using Chatterbox Multilingual.

    This is a generator. It yields human-readable progress strings, and its
    return value (seen by the caller as ``StopIteration.value`` or as the
    result of ``yield from``) is the stitched preview path, or ``None`` if
    anything failed.

    On CUDA, loading and synthesis are delegated to
    ``_gpu_preview_chatterbox_batch``. Otherwise the model is loaded on the
    local device (MPS/CPU), used here with per-segment progress messages, and
    released before stitching.

    Args:
        segments: preview segments; text is taken from "tts_text", else
            "translated_text", else "text".
        reference_audio_path: voice to clone; only its first 10 s are used.
        language_id: language code passed to Chatterbox.
        output_dir: directory for the per-segment parts and the stitched WAV.
    """
    model = None
    error_message = ""
    try:
        # Clip reference audio to max 10 seconds to prevent weird noise/artifacts
        ref_audio_clipped = _clip_audio(reference_audio_path, max_sec=10.0)
        device = _get_device()
        if device == "cuda":
            yield " [preview] Preparing Chatterbox batch preview (device=cuda)...\n"
            part_paths = _gpu_preview_chatterbox_batch(
                segments=segments,
                ref_audio_clipped=ref_audio_clipped,
                language_id=language_id,
                output_dir=output_dir,
            )
            stitched = os.path.join(output_dir, "preview_chatterbox.wav")
            _stitch_wavs(part_paths, stitched)
            yield " β Chatterbox preview complete\n"
            return stitched

        yield f" [preview] Preparing Chatterbox Multilingual (device={device})...\n"
        from chatterbox.mtl_tts import ChatterboxMultilingualTTS

        model = ChatterboxMultilingualTTS.from_pretrained(device)
        part_paths = []
        total = len(segments)
        wav = None
        for i, seg in enumerate(segments):
            yield f" [preview] Chatterbox: Synthesising segment {i+1}/{total}...\n"
            text = seg.get("tts_text", seg.get("translated_text", seg["text"]))
            out_path = os.path.join(output_dir, f"cb_prev_{i:04d}.wav")
            wav = model.generate(
                text[:300],
                language_id=language_id,
                audio_prompt_path=ref_audio_clipped,
                exaggeration=0.5,
                temperature=0.8,
                cfg_weight=0.5,
            )
            # Save from a detached CPU copy, as the CUDA batch path does, so
            # on-device (e.g. MPS) output tensors are written the same way.
            torchaudio.save(
                out_path,
                wav.detach().cpu(),
                model.sr,
                encoding="PCM_S",
                bits_per_sample=16,
            )
            part_paths.append(out_path)

        # Release the model (and the last on-device output) before stitching
        # so accelerator memory can be reclaimed.
        model = wav = None
        _free_memory()
        stitched = os.path.join(output_dir, "preview_chatterbox.wav")
        _stitch_wavs(part_paths, stitched)
        yield " β Chatterbox preview complete\n"
        return stitched
    except Exception as exc:
        # Keep only the message. The exception's traceback can hold frames
        # (e.g. inside model.generate) that still reference the model.
        error_message = str(exc)

    # Free memory *before* yielding the failure line. The consumer may stop
    # iterating after it, and then code after the yield would never run.
    if model is not None:
        model = None
        _free_memory()
    yield f" β Chatterbox failed: {error_message}\n"
    return None
# ── OmniVoice preview ───────────────────────────────────────
| _OMNIVOICE_SR = 24000 | |
| def _free_memory(): | |
| """Aggressively release GPU / unified memory.""" | |
| gc.collect() | |
| if torch.backends.mps.is_available(): | |
| torch.mps.empty_cache() | |
| elif torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| def _get_device() -> str: | |
| if torch.backends.mps.is_available(): | |
| return "mps" | |
| elif torch.cuda.is_available(): | |
| return "cuda" | |
| return "cpu" | |
| def _gpu_preview_omnivoice_segment(model, text, language, ref_audio, ref_text): | |
| return model.generate( | |
| text=text, | |
| language=language, | |
| ref_audio=ref_audio, | |
| ref_text=ref_text, | |
| num_step=32, | |
| speed=1.0, | |
| ) | |
def _preview_omnivoice(
    segments: list[dict],
    reference_audio_path: str,
    language_id: str,
    output_dir: str,
):
    """Generate a stitched preview WAV using OmniVoice.

    This is a generator. It yields human-readable progress strings, and its
    return value (seen by the caller as ``StopIteration.value`` or as the
    result of ``yield from``) is the stitched preview path, or ``None`` if
    anything failed. The model is released, and accelerator memory freed, on
    both the success and the failure path.

    Args:
        segments: preview segments; text is taken from "tts_text", else
            "translated_text", else "text". "text" and "end" are also used to
            build the reference transcript.
        reference_audio_path: voice to clone; only its first 10 s are used.
        language_id: language passed to OmniVoice.
        output_dir: directory for the per-segment parts and the stitched WAV.
    """
    model = None
    error_message = ""
    try:
        from omnivoice import OmniVoice
        import soundfile as sf

        # Clip reference audio to max 10 seconds for speed. This is done before
        # loading the model so an unreadable reference fails without paying for
        # a model load.
        ref_clip_sec = 10.0
        ref_audio_clipped = _clip_audio(reference_audio_path, max_sec=ref_clip_sec)
        # ref_text must transcribe only what's in ref_audio — otherwise the model
        # tries to "finish" the leftover English reference before speaking the target.
        # NOTE(review): ref_text is built only from `segments` as passed in. If the
        # caller pre-filters them to a window shorter than ref_clip_sec, speech in
        # the clip beyond that window goes untranscribed — confirm.
        ref_text = " ".join(
            s["text"] for s in segments if s.get("end", 0.0) <= ref_clip_sec
        ).strip()[:500]

        device = _get_device()
        dtype = torch.float16 if device == "cuda" else torch.float32
        yield f" [preview] Loading OmniVoice on {device} (dtype={dtype})...\n"
        model = OmniVoice.from_pretrained(
            "k2-fsa/OmniVoice",
            device_map=device,
            dtype=dtype,
        )

        part_paths = []
        total = len(segments)
        for i, seg in enumerate(segments):
            yield f" [preview] OmniVoice: Synthesising segment {i+1}/{total}...\n"
            text = seg.get("tts_text", seg.get("translated_text", seg["text"]))
            out_path = os.path.join(output_dir, f"ov_prev_{i:04d}.wav")
            audio = _gpu_preview_omnivoice_segment(
                model=model,
                text=text[:300],
                language=language_id,
                ref_audio=ref_audio_clipped,
                ref_text=ref_text,
            )
            # model.generate() returns List[np.ndarray] at 24 kHz
            sf.write(out_path, audio[0], _OMNIVOICE_SR)
            part_paths.append(out_path)

        # Unload the model before stitching.
        model = None
        _free_memory()
        stitched = os.path.join(output_dir, "preview_omnivoice.wav")
        _stitch_wavs(part_paths, stitched)
        yield " β OmniVoice preview complete\n"
        return stitched
    except Exception as exc:
        # Keep only the message. The traceback can hold frames (e.g. inside
        # model.generate) that keep the model alive.
        error_message = str(exc)

    # Drop our reference and free memory *before* yielding. The consumer may
    # stop iterating after the failure line, and then code after the yield
    # would never run.
    model = None
    _free_memory()
    yield f" β OmniVoice failed: {error_message}\n"
    return None
# ── Public API ──────────────────────────────────────────────
def generate_previews(
    segments: list[dict],
    reference_audio_path: str,
    language_id: str,
    output_dir: str = "tmp/audio/previews",
    max_preview_seconds: float = 5.0,
):
    """Generate a short voice preview for the configured TTS engine.

    This is a generator. Intermediate yields are progress strings. The final
    yield is a dict ``{"__PREVIEW_RESULT__": {engine: path_or_None}}``, where
    ``engine`` is ``TTS_ENGINE``.

    Only segments whose start time is before ``max_preview_seconds`` are
    synthesised, and only with the engine named by ``TTS_ENGINE``. If no
    segment falls in the window, or the engine name is not recognised, the
    result maps the engine to ``None``.

    Args:
        segments: full segment list; each needs a "start" key in seconds.
        reference_audio_path: voice to clone.
        language_id: target language code passed to the engine.
        output_dir: directory for preview files (created if missing).
        max_preview_seconds: preview window, in seconds (default 5 s).
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    preview_segs = _filter_preview_segments(segments, max_preview_seconds)
    if not preview_segs:
        yield " [preview] No segments within preview window β skipping\n"
        yield {"__PREVIEW_RESULT__": {TTS_ENGINE: None}}
        return
    yield f" [preview] Generating preview for {len(preview_segs)} segments (first {max_preview_seconds}s)...\n"
    yield f" [preview] Using TTS_ENGINE={TTS_ENGINE}\n"

    engines = {
        "chatterbox": _preview_chatterbox,
        "omnivoice": _preview_omnivoice,
    }
    preview_fn = engines.get(TTS_ENGINE)
    results: dict[str, str | None] = {}
    if preview_fn is None:
        # Previously an unknown engine silently produced an empty result dict.
        yield f" [preview] Unknown TTS_ENGINE={TTS_ENGINE!r}; skipping\n"
        results[TTS_ENGINE] = None
    else:
        # `yield from` relays the engine's progress lines and evaluates to the
        # generator's return value (the stitched path or None).
        results[TTS_ENGINE] = yield from preview_fn(
            preview_segs, reference_audio_path, language_id, output_dir
        )
    yield {"__PREVIEW_RESULT__": results}