interviewbot / backend /services /tts_service.py
sajith-0701's picture
v3.1
5837391
import asyncio
import logging
import os
import tempfile
from collections import OrderedDict
from functools import wraps
from typing import Tuple
_MODEL_CACHE = {}
_MODEL_LOCK = asyncio.Lock()
_AUDIO_CACHE = OrderedDict()
_AUDIO_CACHE_LOCK = asyncio.Lock()
_SYNTHESIZE_LOCK = asyncio.Lock()
_TORCH_LOAD_PATCHED = False
XTTS_MODEL = "tts_models/multilingual/multi-dataset/xtts_v2"
XTTS_LANGUAGE = "en"
XTTS_SPEED = 1.2
_XTTS_WARM = False
_XTTS_LAST_ERROR: str | None = None
AUDIO_CACHE_MAX_ITEMS = 300
def _resolve_xtts_max_text_length() -> int:
    """Read the XTTS text-length cap from the ``XTTS_MAX_TEXT_LENGTH`` env var.

    Returns:
        The parsed integer, clamped to be non-negative. An unset variable, an
        unparsable value, or a negative number yields 0, which disables
        truncation so the full question text is spoken.
    """
    raw = os.getenv("XTTS_MAX_TEXT_LENGTH", "0")
    try:
        return max(0, int(raw))
    except ValueError:
        # Malformed configuration means "no truncation" rather than failing
        # module import. int(str) only raises ValueError, so nothing else
        # needs catching here.
        return 0


XTTS_MAX_TEXT_LENGTH = _resolve_xtts_max_text_length()
# Fixed XTTS v2 speaker names, approved by the user as stable voices:
#   female -> "Alexandra Hisakawa" (speaker index 45)
#   male   -> "Abrahan Mack"       (speaker index 21)
# "auto" has no dedicated voice and reuses the female speaker.
XTTS_SPEAKER_BY_GENDER = dict(
    female="Alexandra Hisakawa",
    male="Abrahan Mack",
    auto="Alexandra Hisakawa",
)
def _resolve_xtts_checkpoint_trust() -> bool:
    """Report whether XTTS checkpoints are trusted for full unpickling.

    Reads ``XTTS_TRUSTED_CHECKPOINTS`` (default ``"1"``). The values
    1/true/yes/on, compared case-insensitively after trimming whitespace,
    mean trusted; any other value disables the torch.load compatibility patch.
    """
    raw = os.getenv("XTTS_TRUSTED_CHECKPOINTS", "1")
    normalized = raw.strip().lower()
    truthy = ("1", "true", "yes", "on")
    return normalized in truthy
def _ensure_torch_load_compat_for_xtts() -> None:
    """Wrap ``torch.load`` so it defaults to ``weights_only=False``.

    PyTorch 2.6+ defaults ``torch.load`` to weights-only unpickling, which
    Coqui XTTS checkpoints do not satisfy. When checkpoints are trusted (per
    ``_resolve_xtts_checkpoint_trust``), this installs the wrapper at most once
    per process. Callers that pass ``weights_only`` explicitly keep their value.
    Nothing happens if torch cannot be imported or has no callable ``load``.

    NOTE(review): the wrapper replaces the ``torch.load`` attribute globally,
    so later calls made through ``torch.load`` anywhere in the process (not
    only XTTS) unpickle full objects by default. Only enable this for trusted
    checkpoint sources.
    """
    global _TORCH_LOAD_PATCHED
    if _TORCH_LOAD_PATCHED:
        return
    if not _resolve_xtts_checkpoint_trust():
        return
    try:
        import torch
    except Exception:
        return

    wrapped_load = getattr(torch, "load", None)
    if not callable(wrapped_load):
        return

    @wraps(wrapped_load)
    def _torch_load_compat(*args, **kwargs):
        # Coqui XTTS checkpoints require full object unpickling on newer PyTorch.
        if "weights_only" not in kwargs:
            kwargs["weights_only"] = False
        return wrapped_load(*args, **kwargs)

    torch.load = _torch_load_compat
    _TORCH_LOAD_PATCHED = True
def _select_model(voice_gender: str) -> Tuple[str, str | None]:
gender = (voice_gender or "female").strip().lower()
if gender == "male":
# Multi-speaker model; use a male VCTK speaker token.
return "tts_models/en/vctk/vits", "p226"
# Default female-like English voice model.
return "tts_models/en/ljspeech/tacotron2-DDC", None
async def _get_tts_model(model_name: str):
    """Return the Coqui ``TTS`` instance for *model_name*, loading it on first use.

    Loaded models are memoized in ``_MODEL_CACHE``. ``_MODEL_LOCK`` is held for
    the whole lookup-and-load, so concurrent callers never load the same model
    twice. The blocking load runs in a worker thread via ``asyncio.to_thread``.

    Device placement is read from ``XTTS_USE_GPU``:
        - 1/true/yes/on attempts CUDA.
        - 0/false/no/off uses CPU.
        - Anything else (default "auto") uses CUDA when
          ``torch.cuda.is_available()`` says so.
    A failed move to CUDA falls back to CPU; a failed move to CPU is ignored.

    Raises:
        RuntimeError: if ``TTS.api`` cannot be imported. Errors from the
            ``TTS`` constructor propagate unchanged and nothing is cached.
    """
    async with _MODEL_LOCK:
        if model_name in _MODEL_CACHE:
            return _MODEL_CACHE[model_name]

        def _wants_gpu() -> bool:
            # Explicit env preference wins; otherwise probe CUDA availability.
            preference = os.getenv("XTTS_USE_GPU", "auto").strip().lower()
            if preference in {"1", "true", "yes", "on"}:
                return True
            if preference in {"0", "false", "no", "off"}:
                return False
            try:
                import torch
                return bool(torch.cuda.is_available())
            except Exception:
                return False

        def _place_on_cpu(engine) -> None:
            try:
                engine.to("cpu")
            except Exception:
                pass

        def _load_model():
            _ensure_torch_load_compat_for_xtts()
            try:
                from TTS.api import TTS
            except Exception as exc:
                raise RuntimeError(
                    "Coqui TTS is not installed in the active Python environment"
                ) from exc
            on_gpu = _wants_gpu()
            # TTS(..., gpu=...) is deprecated upstream. Load once, then move model.
            engine = TTS(model_name=model_name, progress_bar=False)
            if on_gpu:
                try:
                    engine.to("cuda")
                    return engine
                except Exception:
                    # Graceful CPU fallback when CUDA runtime is unavailable/mismatched.
                    pass
            _place_on_cpu(engine)
            return engine

        loaded = await asyncio.to_thread(_load_model)
        _MODEL_CACHE[model_name] = loaded
        return loaded
def _resolve_xtts_speaker(voice_gender: str) -> str:
    """Map a voice gender to its XTTS speaker name.

    The gender is trimmed and lower-cased. None, empty, or unknown values
    resolve to the "female" entry of ``XTTS_SPEAKER_BY_GENDER``.
    """
    key = (voice_gender or "female").strip().lower()
    if key in XTTS_SPEAKER_BY_GENDER:
        return XTTS_SPEAKER_BY_GENDER[key]
    return XTTS_SPEAKER_BY_GENDER["female"]
def _normalize_text_for_speech(value: str, max_length: int = XTTS_MAX_TEXT_LENGTH) -> str:
content = " ".join((value or "").strip().split())
if max_length <= 0:
return content
if len(content) <= max_length:
return content
trimmed = content[:max_length].rstrip()
# Keep sentence boundaries cleaner when truncating.
for marker in ("?", "!", "."):
if marker in trimmed:
head = trimmed.rsplit(marker, 1)[0].strip()
if len(head) >= max_length // 2:
return f"{head}{marker}"
return trimmed
async def warmup_xtts_model() -> bool:
    """Load the XTTS model ahead of time to avoid a cold start on the first question.

    Returns:
        True once the model has been loaded (now or by an earlier successful
        call). False if loading raised; the error text is stored in
        ``_XTTS_LAST_ERROR`` and the exception is not propagated, so API
        startup keeps going.
    """
    global _XTTS_WARM, _XTTS_LAST_ERROR
    if not _XTTS_WARM:
        try:
            await _get_tts_model(XTTS_MODEL)
        except Exception as exc:
            # Routes decide whether to surface this via get_xtts_warmup_state().
            _XTTS_LAST_ERROR = str(exc)
            return False
        _XTTS_WARM = True
        _XTTS_LAST_ERROR = None
    return True
def get_xtts_warmup_state() -> dict:
    """Return a snapshot of the XTTS warmup status.

    Returns:
        ``{"is_warm": bool, "last_error": str | None}``. ``last_error`` holds
        the message from the most recent failed warmup. It is None initially
        and after a successful warmup.
    """
    state = dict(is_warm=_XTTS_WARM, last_error=_XTTS_LAST_ERROR)
    return state
def _synthesize_xtts_to_file(tts, text: str, speaker: str, file_path: str) -> None:
    """Render *text* with the XTTS model *tts* into the WAV file *file_path*.

    Speaks with the named XTTS *speaker* in ``XTTS_LANGUAGE`` and first asks
    for ``speed=XTTS_SPEED`` for faster delivery of interview prompts.

    Raises:
        TypeError: re-raised when the failure is not about the ``speed``
            keyword. Only a speed-related TypeError triggers a single retry
            without ``speed``, so an unrelated error does not re-run the whole
            (expensive) synthesis.
    """
    kwargs = {
        "text": text,
        "file_path": file_path,
        "speaker": speaker,
        "language": XTTS_LANGUAGE,
    }
    try:
        tts.tts_to_file(**kwargs, speed=XTTS_SPEED)
    except TypeError as exc:
        # Some model/runtime combinations do not accept a speed argument; the
        # resulting "unexpected keyword argument" error names the keyword.
        if "speed" not in str(exc):
            raise
        tts.tts_to_file(**kwargs)
def _build_audio_cache_key(text: str, voice_gender: str) -> str:
    """Build the audio-cache key ``"<gender>::<text>"``.

    The gender is trimmed and lower-cased (None or "" becomes "female"), and
    the text is stripped of surrounding whitespace.
    """
    gender_part = (voice_gender or "female").strip().lower()
    text_part = text.strip()
    return gender_part + "::" + text_part
async def _get_cached_audio(cache_key: str) -> bytes | None:
    """Look up WAV bytes in the LRU audio cache.

    Returns:
        The cached bytes, or None on a miss. A hit is moved to the
        most-recently-used end, so ``_set_cached_audio`` evicts it last.
    """
    async with _AUDIO_CACHE_LOCK:
        hit = _AUDIO_CACHE.get(cache_key)
        if hit is not None:
            _AUDIO_CACHE.move_to_end(cache_key)
        return hit
async def _set_cached_audio(cache_key: str, data: bytes) -> None:
    """Store WAV bytes as the most-recently-used audio-cache entry.

    Afterwards, least-recently-used entries are evicted from the front until
    at most ``AUDIO_CACHE_MAX_ITEMS`` remain.
    """
    async with _AUDIO_CACHE_LOCK:
        _AUDIO_CACHE[cache_key] = data
        _AUDIO_CACHE.move_to_end(cache_key)
        overflow = len(_AUDIO_CACHE) - AUDIO_CACHE_MAX_ITEMS
        for _ in range(max(0, overflow)):
            _AUDIO_CACHE.popitem(last=False)
async def _synthesize_fallback_wav(text: str, voice_gender: str) -> bytes:
    """Synthesize *text* with the non-XTTS fallback model and return WAV bytes.

    The model and optional speaker come from ``_select_model``. Audio is
    rendered in a worker thread into a temporary ``.wav`` file and read back.
    The file is removed afterwards, even on failure.
    """
    model_name, speaker = _select_model(voice_gender)
    tts = await _get_tts_model(model_name)
    handle, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(handle)
    options = {"text": text, "file_path": wav_path}
    if speaker:
        # Only multi-speaker models are given a speaker token.
        options["speaker"] = speaker
    try:
        await asyncio.to_thread(tts.tts_to_file, **options)
        with open(wav_path, "rb") as wav_file:
            return wav_file.read()
    finally:
        if os.path.exists(wav_path):
            os.remove(wav_path)
async def prefetch_wav(text: str, voice_gender: str = "female") -> None:
    """Best-effort: synthesize *text* now so a later request hits the audio cache.

    Never raises on failure (``Exception`` subclasses only). The failure is
    logged at DEBUG level with its traceback, and a later ``synthesize_wav``
    call may still succeed.
    """
    try:
        await synthesize_wav(text, voice_gender)
    except Exception:
        # Prefetch is only an optimization, so don't propagate, but record
        # why it failed instead of silently discarding the error.
        logging.getLogger(__name__).debug(
            "TTS prefetch failed for voice_gender=%r", voice_gender, exc_info=True
        )
async def _synthesize_xtts_wav(content: str, voice_gender: str) -> bytes:
    """Render already-normalized *content* with XTTS and return the WAV bytes.

    Loads (or reuses) the XTTS model, synthesizes in a worker thread into a
    temporary ``.wav`` file, and reads it back. The file is always removed.
    Any failure, including the model failing to load, propagates to the caller.
    """
    speaker = _resolve_xtts_speaker(voice_gender)
    tts = await _get_tts_model(XTTS_MODEL)
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        await asyncio.to_thread(
            _synthesize_xtts_to_file,
            tts,
            text=content,
            speaker=speaker,
            file_path=tmp_path,
        )
        with open(tmp_path, "rb") as f:
            return f.read()
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


async def synthesize_wav(text: str, voice_gender: str = "female") -> bytes:
    """Return WAV audio for *text*, using XTTS with a fallback TTS model.

    Args:
        text: Text to speak. Whitespace is collapsed, and it may be truncated
            by ``_normalize_text_for_speech``.
        voice_gender: "male", "female" or "auto" (case/whitespace-insensitive).
            Anything else is treated as "female".

    Returns:
        WAV file bytes. Results are memoized in the LRU audio cache, keyed by
        normalized gender and text.

    Raises:
        ValueError: if *text* is empty after normalization.
        Exception: whatever the fallback model raises when both XTTS and the
            fallback fail.
    """
    content = _normalize_text_for_speech(text)
    if not content:
        raise ValueError("text is required")
    normalized_gender = (voice_gender or "female").strip().lower()
    if normalized_gender not in {"male", "female", "auto"}:
        normalized_gender = "female"
    cache_key = _build_audio_cache_key(content, normalized_gender)
    cached = await _get_cached_audio(cache_key)
    if cached:
        return cached
    async with _SYNTHESIZE_LOCK:
        # Recheck cache after waiting for lock in case another request already synthesized it.
        cached = await _get_cached_audio(cache_key)
        if cached:
            return cached
        try:
            # Model loading is inside the try as well, so a missing or broken
            # XTTS install also reaches the fallback below.
            wav = await _synthesize_xtts_wav(content, normalized_gender)
        except Exception:
            # Keep speech available even if XTTS runtime has issues, but leave
            # a trace in the logs.
            # NOTE(review): fallback audio is cached under the same key, so this
            # text keeps the fallback voice until it is evicted from the cache.
            logging.getLogger(__name__).warning(
                "XTTS synthesis failed; using fallback TTS model", exc_info=True
            )
            wav = await _synthesize_fallback_wav(content, normalized_gender)
        await _set_cached_audio(cache_key, wav)
        return wav