""" Step 4 (optional): Generate ~5-second voice previews with multiple TTS models. Loads each model sequentially to manage memory, synthesises the first ~5 s of translated segments, stitches them into a single preview WAV per model, and returns the file paths. Currently supports: Chatterbox Multilingual, OmniVoice. Environment: TTS_ENGINE: "chatterbox" or "omnivoice" — controls which engine loads. """ import gc import os import subprocess from pathlib import Path import torch import torchaudio TTS_ENGINE = os.getenv("TTS_ENGINE", "chatterbox").lower() import spaces def _filter_preview_segments(segments: list[dict], max_seconds: float = 30.0) -> list[dict]: """Return segments whose start time is within the first `max_seconds`.""" return [s for s in segments if s["start"] < max_seconds] def _stitch_wavs(wav_paths: list[str], output_path: str) -> str: """Concatenate multiple WAV files into one using ffmpeg.""" if not wav_paths: raise ValueError("No WAVs to stitch") if len(wav_paths) == 1: import shutil shutil.copy(wav_paths[0], output_path) return output_path concat_list = output_path + ".concat.txt" with open(concat_list, "w") as f: for p in wav_paths: f.write(f"file '{os.path.abspath(p)}'\n") cmd = [ "ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", concat_list, "-c", "copy", output_path, ] result = subprocess.run(cmd, capture_output=True, text=True) os.remove(concat_list) if result.returncode != 0: raise RuntimeError(f"ffmpeg concat failed: {result.stderr[:300]}") return output_path def _ensure_browser_wav(path: str) -> str: """Re-encode a WAV to 16-bit PCM 44100 Hz so browsers can play it.""" safe_path = path.replace(".wav", "_safe.wav") cmd = [ "ffmpeg", "-y", "-i", path, "-ar", "44100", "-ac", "1", "-sample_fmt", "s16", "-c:a", "pcm_s16le", safe_path, ] result = subprocess.run(cmd, capture_output=True, text=True) if result.returncode == 0: os.replace(safe_path, path) return path def _clip_audio(path: str, max_sec: float = 10.0) -> str: """Clip audio to max_sec to prevent excessively slow voice cloning.""" wav, sr = torchaudio.load(path) frames = int(max_sec * sr) if wav.shape[1] > frames: wav = wav[:, :frames] out_path = path.replace(".wav", "_clipped.wav") torchaudio.save(out_path, wav, sr) return out_path return path @spaces.GPU(duration=60) def _gpu_preview_chatterbox_batch( segments: list[dict], ref_audio_clipped: str, language_id: str, output_dir: str, ): """Load + run Chatterbox preview synthesis inside one GPU scope.""" from chatterbox.mtl_tts import ChatterboxMultilingualTTS print(" [preview] Loading Chatterbox in GPU scope...") model = ChatterboxMultilingualTTS.from_pretrained("cuda") part_paths = [] total = len(segments) for i, seg in enumerate(segments): text = seg.get("tts_text", seg.get("translated_text", seg["text"])) out_path = os.path.join(output_dir, f"cb_prev_{i:04d}.wav") print(f" [preview] Chatterbox: Synthesising segment {i+1}/{total}...") wav = model.generate( text[:300], language_id=language_id, audio_prompt_path=ref_audio_clipped, exaggeration=0.5, temperature=0.8, cfg_weight=0.5, ) torchaudio.save( out_path, wav.detach().cpu(), model.sr, encoding="PCM_S", bits_per_sample=16, ) part_paths.append(out_path) return part_paths # ── Chatterbox Multilingual preview ────────────────────────── def _preview_chatterbox( segments: list[dict], reference_audio_path: str, language_id: str, output_dir: str, ): """Generate a stitched preview WAV using Chatterbox Multilingual.""" try: # Clip reference audio to max 10 seconds to prevent weird noise/artifacts ref_audio_clipped = _clip_audio(reference_audio_path, max_sec=10.0) device = _get_device() if device == "cuda": yield " [preview] Preparing Chatterbox batch preview (device=cuda)...\n" part_paths = _gpu_preview_chatterbox_batch( segments=segments, ref_audio_clipped=ref_audio_clipped, language_id=language_id, output_dir=output_dir, ) stitched = os.path.join(output_dir, "preview_chatterbox.wav") _stitch_wavs(part_paths, stitched) yield " ✓ Chatterbox preview complete\n" return stitched yield f" [preview] Preparing Chatterbox Multilingual (device={device})...\n" from chatterbox.mtl_tts import ChatterboxMultilingualTTS model = ChatterboxMultilingualTTS.from_pretrained(device) part_paths = [] total = len(segments) for i, seg in enumerate(segments): yield f" [preview] Chatterbox: Synthesising segment {i+1}/{total}...\n" text = seg.get("tts_text", seg.get("translated_text", seg["text"])) out_path = os.path.join(output_dir, f"cb_prev_{i:04d}.wav") wav = model.generate( text[:300], language_id=language_id, audio_prompt_path=ref_audio_clipped, exaggeration=0.5, temperature=0.8, cfg_weight=0.5, ) torchaudio.save(out_path, wav, model.sr, encoding="PCM_S", bits_per_sample=16) part_paths.append(out_path) stitched = os.path.join(output_dir, "preview_chatterbox.wav") _stitch_wavs(part_paths, stitched) yield " ✓ Chatterbox preview complete\n" return stitched except Exception as e: yield f" ✗ Chatterbox failed: {e}\n" return None # ── OmniVoice preview ─────────────────────────────────────── _OMNIVOICE_SR = 24000 def _free_memory(): """Aggressively release GPU / unified memory.""" gc.collect() if torch.backends.mps.is_available(): torch.mps.empty_cache() elif torch.cuda.is_available(): torch.cuda.empty_cache() def _get_device() -> str: if torch.backends.mps.is_available(): return "mps" elif torch.cuda.is_available(): return "cuda" return "cpu" @spaces.GPU(duration=30) def _gpu_preview_omnivoice_segment(model, text, language, ref_audio, ref_text): return model.generate( text=text, language=language, ref_audio=ref_audio, ref_text=ref_text, num_step=32, speed=1.0, ) def _preview_omnivoice( segments: list[dict], reference_audio_path: str, language_id: str, output_dir: str, ): """Generate a stitched preview WAV using OmniVoice.""" try: from omnivoice import OmniVoice import soundfile as sf device = _get_device() dtype = torch.float16 if device == "cuda" else torch.float32 yield f" [preview] Loading OmniVoice on {device} (dtype={dtype})...\n" model = OmniVoice.from_pretrained( "k2-fsa/OmniVoice", device_map=device, dtype=dtype, ) # Clip reference audio to max 10 seconds for speed ref_clip_sec = 10.0 ref_audio_clipped = _clip_audio(reference_audio_path, max_sec=ref_clip_sec) # ref_text must transcribe only what's in ref_audio — otherwise the model # tries to "finish" the leftover English reference before speaking the target. ref_text = " ".join( s["text"] for s in segments if s.get("end", 0.0) <= ref_clip_sec ).strip()[:500] part_paths = [] total = len(segments) for i, seg in enumerate(segments): yield f" [preview] OmniVoice: Synthesising segment {i+1}/{total}...\n" text = seg.get("tts_text", seg.get("translated_text", seg["text"])) out_path = os.path.join(output_dir, f"ov_prev_{i:04d}.wav") audio = _gpu_preview_omnivoice_segment( model=model, text=text[:300], language=language_id, ref_audio=ref_audio_clipped, ref_text=ref_text, ) # model.generate() returns List[np.ndarray] at 24 kHz sf.write(out_path, audio[0], _OMNIVOICE_SR) part_paths.append(out_path) # Unload model del model _free_memory() stitched = os.path.join(output_dir, "preview_omnivoice.wav") _stitch_wavs(part_paths, stitched) yield " ✓ OmniVoice preview complete\n" return stitched except Exception as e: yield f" ✗ OmniVoice failed: {e}\n" _free_memory() return None # ── Public API ─────────────────────────────────────────────── def generate_previews( segments: list[dict], reference_audio_path: str, language_id: str, output_dir: str = "tmp/audio/previews", max_preview_seconds: float = 5.0, ): """ Generate ~30 s preview clips as a generator yielding progress messages. Finally yields a dict containing the result paths. Only generates preview for the TTS_ENGINE configured for this Space. """ Path(output_dir).mkdir(parents=True, exist_ok=True) preview_segs = _filter_preview_segments(segments, max_preview_seconds) if not preview_segs: yield " [preview] No segments within preview window — skipping\n" yield {"__PREVIEW_RESULT__": {TTS_ENGINE: None}} return yield f" [preview] Generating preview for {len(preview_segs)} segments (first {max_preview_seconds}s)...\n" yield f" [preview] Using TTS_ENGINE={TTS_ENGINE}\n" results: dict[str, str | None] = {} # Generate preview only for the configured TTS engine if TTS_ENGINE == "chatterbox": cb_gen = _preview_chatterbox(preview_segs, reference_audio_path, language_id, output_dir) try: while True: yield next(cb_gen) except StopIteration as e: results["chatterbox"] = e.value elif TTS_ENGINE == "omnivoice": ov_gen = _preview_omnivoice(preview_segs, reference_audio_path, language_id, output_dir) try: while True: yield next(ov_gen) except StopIteration as e: results["omnivoice"] = e.value yield {"__PREVIEW_RESULT__": results}