# videovoice/steps/s4_preview.py
# Rafii — deploy: switch to chatterbox requirements @ 98aec56 (d33ca97)
"""
Step 4 (optional): Generate short (~5-second) voice previews with the configured TTS engine.
Only the engine selected by TTS_ENGINE is loaded, to keep memory usage down. It
synthesises the translated segments that start within the first few seconds
(5 s by default), stitches them into a single preview WAV, and yields progress
messages followed by a result dict containing the preview path.
Currently supports: Chatterbox Multilingual, OmniVoice.
Environment:
TTS_ENGINE: "chatterbox" or "omnivoice" — controls which engine loads.
"""
import gc
import os
import subprocess
from pathlib import Path
import torch
import torchaudio
# Name of the TTS backend for this deployment. It is read once at import time
# from the TTS_ENGINE environment variable (default "chatterbox") and lower-cased,
# so the comparison is case-insensitive. generate_previews() dispatches on
# "chatterbox" and "omnivoice".
TTS_ENGINE = os.getenv("TTS_ENGINE", "chatterbox").lower()
import spaces
def _filter_preview_segments(segments: list[dict], max_seconds: float = 30.0) -> list[dict]:
"""Return segments whose start time is within the first `max_seconds`."""
return [s for s in segments if s["start"] < max_seconds]
def _stitch_wavs(wav_paths: list[str], output_path: str) -> str:
"""Concatenate multiple WAV files into one using ffmpeg."""
if not wav_paths:
raise ValueError("No WAVs to stitch")
if len(wav_paths) == 1:
import shutil
shutil.copy(wav_paths[0], output_path)
return output_path
concat_list = output_path + ".concat.txt"
with open(concat_list, "w") as f:
for p in wav_paths:
f.write(f"file '{os.path.abspath(p)}'\n")
cmd = [
"ffmpeg", "-y",
"-f", "concat", "-safe", "0",
"-i", concat_list,
"-c", "copy",
output_path,
]
result = subprocess.run(cmd, capture_output=True, text=True)
os.remove(concat_list)
if result.returncode != 0:
raise RuntimeError(f"ffmpeg concat failed: {result.stderr[:300]}")
return output_path
def _ensure_browser_wav(path: str) -> str:
"""Re-encode a WAV to 16-bit PCM 44100 Hz so browsers can play it."""
safe_path = path.replace(".wav", "_safe.wav")
cmd = [
"ffmpeg", "-y", "-i", path,
"-ar", "44100", "-ac", "1", "-sample_fmt", "s16",
"-c:a", "pcm_s16le",
safe_path,
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
os.replace(safe_path, path)
return path
def _clip_audio(path: str, max_sec: float = 10.0) -> str:
    """Return a path to audio that is at most ``max_sec`` seconds long.

    Long voice-cloning references slow synthesis down. Audio longer than
    ``max_sec`` is truncated (keeping the beginning) and written as WAV next to
    the input, as ``<name>_clipped.wav``. Audio that is already short enough
    is not rewritten, and the original ``path`` is returned.
    """
    wav, sr = torchaudio.load(path)
    frames = int(max_sec * sr)
    if wav.shape[1] <= frames:
        return path
    # Build the output name from the extension only. str.replace(".wav", ...)
    # left non-".wav" inputs (e.g. ".mp3", ".WAV") with out_path == path,
    # which overwrote the caller's original reference file.
    root, _ext = os.path.splitext(path)
    out_path = f"{root}_clipped.wav"
    torchaudio.save(out_path, wav[:, :frames], sr)
    return out_path
@spaces.GPU(duration=60)
def _gpu_preview_chatterbox_batch(
    segments: list[dict],
    ref_audio_clipped: str,
    language_id: str,
    output_dir: str,
):
    """Synthesise every preview segment with Chatterbox inside a single GPU scope.

    The model is loaded on "cuda" inside this decorated call. Each segment's
    text (tts_text, else translated_text, else text; truncated to 300 chars)
    is voiced using the clipped reference as the audio prompt. The result is
    written as 16-bit PCM WAV ``cb_prev_<index>.wav`` in ``output_dir``.

    Returns:
        The written WAV paths, in segment order.
    """
    from chatterbox.mtl_tts import ChatterboxMultilingualTTS

    print(" [preview] Loading Chatterbox in GPU scope...")
    tts = ChatterboxMultilingualTTS.from_pretrained("cuda")

    written = []
    segment_count = len(segments)
    for idx, segment in enumerate(segments):
        fallback_text = segment.get("translated_text", segment["text"])
        utterance = segment.get("tts_text", fallback_text)
        wav_path = os.path.join(output_dir, f"cb_prev_{idx:04d}.wav")
        print(f" [preview] Chatterbox: Synthesising segment {idx+1}/{segment_count}...")
        audio = tts.generate(
            utterance[:300],
            language_id=language_id,
            audio_prompt_path=ref_audio_clipped,
            exaggeration=0.5,
            temperature=0.8,
            cfg_weight=0.5,
        )
        host_audio = audio.detach().cpu()
        torchaudio.save(wav_path, host_audio, tts.sr, encoding="PCM_S", bits_per_sample=16)
        written.append(wav_path)
    return written
# ── Chatterbox Multilingual preview ──────────────────────────
def _preview_chatterbox(
    segments: list[dict],
    reference_audio_path: str,
    language_id: str,
    output_dir: str,
):
    """Generate a stitched preview WAV using Chatterbox Multilingual.

    This is a generator. It yields human-readable progress strings and returns
    (via ``StopIteration.value``) the path of ``preview_chatterbox.wav`` in
    ``output_dir``, or ``None`` on failure. Exceptions are reported as a
    progress line rather than raised.

    On CUDA, synthesis runs entirely inside ``_gpu_preview_chatterbox_batch``.
    On other devices the model is loaded here, used for every segment, and
    released before the parts are stitched.
    """
    model = None
    try:
        # Clip reference audio to max 10 seconds to prevent weird noise/artifacts
        ref_audio_clipped = _clip_audio(reference_audio_path, max_sec=10.0)
        device = _get_device()
        if device == "cuda":
            yield " [preview] Preparing Chatterbox batch preview (device=cuda)...\n"
            part_paths = _gpu_preview_chatterbox_batch(
                segments=segments,
                ref_audio_clipped=ref_audio_clipped,
                language_id=language_id,
                output_dir=output_dir,
            )
        else:
            yield f" [preview] Preparing Chatterbox Multilingual (device={device})...\n"
            from chatterbox.mtl_tts import ChatterboxMultilingualTTS
            model = ChatterboxMultilingualTTS.from_pretrained(device)
            part_paths = []
            total = len(segments)
            for i, seg in enumerate(segments):
                yield f" [preview] Chatterbox: Synthesising segment {i+1}/{total}...\n"
                text = seg.get("tts_text", seg.get("translated_text", seg["text"]))
                out_path = os.path.join(output_dir, f"cb_prev_{i:04d}.wav")
                wav = model.generate(
                    text[:300],
                    language_id=language_id,
                    audio_prompt_path=ref_audio_clipped,
                    exaggeration=0.5,
                    temperature=0.8,
                    cfg_weight=0.5,
                )
                # Match the CUDA path: save a detached, host-side tensor.
                torchaudio.save(
                    out_path,
                    wav.detach().cpu(),
                    model.sr,
                    encoding="PCM_S",
                    bits_per_sample=16,
                )
                part_paths.append(out_path)
            # Release the model before stitching so its memory can be reclaimed.
            model = None
            _free_memory()
        stitched = os.path.join(output_dir, "preview_chatterbox.wav")
        _stitch_wavs(part_paths, stitched)
        yield " βœ“ Chatterbox preview complete\n"
        return stitched
    except Exception as e:
        failure = f" βœ— Chatterbox failed: {e}\n"
    # Only reached on failure. Cleanup happens outside the except block, so the
    # exception and its traceback (whose frames may still reference the model)
    # have already been released.
    model = None
    _free_memory()
    yield failure
    return None
# ── OmniVoice preview ───────────────────────────────────────
# Sample rate (Hz) used when writing OmniVoice output to disk. The synthesis
# loop in _preview_omnivoice notes that model.generate() returns 24 kHz audio.
_OMNIVOICE_SR = 24000
def _free_memory():
"""Aggressively release GPU / unified memory."""
gc.collect()
if torch.backends.mps.is_available():
torch.mps.empty_cache()
elif torch.cuda.is_available():
torch.cuda.empty_cache()
def _get_device() -> str:
if torch.backends.mps.is_available():
return "mps"
elif torch.cuda.is_available():
return "cuda"
return "cpu"
@spaces.GPU(duration=30)
def _gpu_preview_omnivoice_segment(model, text, language, ref_audio, ref_text):
    """Run a single OmniVoice ``generate`` call inside a short GPU scope.

    Uses 32 generation steps at normal speed and returns whatever
    ``model.generate`` returns, unchanged.
    """
    generation_kwargs = {
        "text": text,
        "language": language,
        "ref_audio": ref_audio,
        "ref_text": ref_text,
        "num_step": 32,
        "speed": 1.0,
    }
    return model.generate(**generation_kwargs)
def _preview_omnivoice(
    segments: list[dict],
    reference_audio_path: str,
    language_id: str,
    output_dir: str,
):
    """Generate a stitched preview WAV using OmniVoice.

    This is a generator. It yields human-readable progress strings and returns
    (via ``StopIteration.value``) the path of ``preview_omnivoice.wav`` in
    ``output_dir``, or ``None`` if anything fails. Failures are reported as a
    progress line, not raised.

    The model is loaded once. Each segment is synthesised through the
    ``_gpu_preview_omnivoice_segment`` wrapper, and the model is released
    before the parts are stitched.
    """
    model = None
    try:
        from omnivoice import OmniVoice
        import soundfile as sf
        device = _get_device()
        # Half precision only on CUDA; MPS and CPU stay in float32.
        dtype = torch.float16 if device == "cuda" else torch.float32
        yield f" [preview] Loading OmniVoice on {device} (dtype={dtype})...\n"
        model = OmniVoice.from_pretrained(
            "k2-fsa/OmniVoice",
            device_map=device,
            dtype=dtype,
        )
        # Clip reference audio to max 10 seconds for speed
        ref_clip_sec = 10.0
        ref_audio_clipped = _clip_audio(reference_audio_path, max_sec=ref_clip_sec)
        # ref_text must transcribe only what's in ref_audio — otherwise the model
        # tries to "finish" the leftover English reference before speaking the target.
        # NOTE(review): when called from generate_previews, `segments` is only
        # the preview subset (start < max_preview_seconds). If the reference
        # audio shares the segments' timeline, speech between that cutoff and
        # ref_clip_sec is in the clip but not in ref_text. Confirm intended scope.
        ref_text = " ".join(
            s["text"] for s in segments if s.get("end", 0.0) <= ref_clip_sec
        ).strip()[:500]
        part_paths = []
        total = len(segments)
        for i, seg in enumerate(segments):
            yield f" [preview] OmniVoice: Synthesising segment {i+1}/{total}...\n"
            text = seg.get("tts_text", seg.get("translated_text", seg["text"]))
            out_path = os.path.join(output_dir, f"ov_prev_{i:04d}.wav")
            audio = _gpu_preview_omnivoice_segment(
                model=model,
                text=text[:300],
                language=language_id,
                ref_audio=ref_audio_clipped,
                ref_text=ref_text,
            )
            # model.generate() returns List[np.ndarray] at 24 kHz
            sf.write(out_path, audio[0], _OMNIVOICE_SR)
            part_paths.append(out_path)
        # Unload the model before stitching so its memory can be reclaimed.
        model = None
        _free_memory()
        stitched = os.path.join(output_dir, "preview_omnivoice.wav")
        _stitch_wavs(part_paths, stitched)
        yield " βœ“ OmniVoice preview complete\n"
        return stitched
    except Exception as e:
        failure = f" βœ— OmniVoice failed: {e}\n"
    # Only reached on failure. Freeing inside the except block was ineffective:
    # `model` and the exception's traceback (whose frames may reference the
    # model) were still alive. Both are gone by this point.
    model = None
    _free_memory()
    yield failure
    return None
# ── Public API ───────────────────────────────────────────────
def generate_previews(
    segments: list[dict],
    reference_audio_path: str,
    language_id: str,
    output_dir: str = "tmp/audio/previews",
    max_preview_seconds: float = 5.0,
):
    """
    Generate a short voice preview with the configured TTS engine.

    This is a generator. It yields human-readable progress strings, and its
    final item is a dict ``{"__PREVIEW_RESULT__": {engine: path_or_None}}``.
    Only segments starting before ``max_preview_seconds`` (5 s by default) are
    synthesised, and only the engine named by ``TTS_ENGINE`` is used. If no
    segment falls in the window, or ``TTS_ENGINE`` is not a supported engine,
    the result maps ``TTS_ENGINE`` to ``None``.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    preview_segs = _filter_preview_segments(segments, max_preview_seconds)
    if not preview_segs:
        yield " [preview] No segments within preview window β€” skipping\n"
        yield {"__PREVIEW_RESULT__": {TTS_ENGINE: None}}
        return
    yield f" [preview] Generating preview for {len(preview_segs)} segments (first {max_preview_seconds}s)...\n"
    yield f" [preview] Using TTS_ENGINE={TTS_ENGINE}\n"
    results: dict[str, str | None] = {}
    # Generate a preview only for the configured TTS engine.
    preview_fns = {
        "chatterbox": _preview_chatterbox,
        "omnivoice": _preview_omnivoice,
    }
    preview_fn = preview_fns.get(TTS_ENGINE)
    if preview_fn is None:
        # Report unsupported engines explicitly, mirroring the empty-window
        # branch's {TTS_ENGINE: None} result instead of returning an empty dict.
        yield f" [preview] Unsupported TTS_ENGINE={TTS_ENGINE!r} - skipping\n"
        results[TTS_ENGINE] = None
    else:
        # `yield from` forwards each progress line and evaluates to the
        # sub-generator's return value (the stitched WAV path or None).
        results[TTS_ENGINE] = yield from preview_fn(
            preview_segs, reference_audio_path, language_id, output_dir
        )
    yield {"__PREVIEW_RESULT__": results}