sajith-0701 committed on
Commit
0eda9c2
·
1 Parent(s): d50ee26

XTTS and Whisper are initialized

.gitignore CHANGED
@@ -11,4 +11,5 @@ dist
11
  inter
12
  Resume.pdf
13
  LANGGRAPH_AND_TOOLS.md
14
- WORKFLOW.md
 
 
11
  inter
12
  Resume.pdf
13
  LANGGRAPH_AND_TOOLS.md
14
+ WORKFLOW.md
15
+ voice_name_list_xtts.txt
backend/main.py CHANGED
@@ -1,4 +1,5 @@
1
  from contextlib import asynccontextmanager
 
2
  from fastapi import FastAPI
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.staticfiles import StaticFiles
@@ -8,6 +9,8 @@ import os
8
 
9
  from config import get_settings
10
  from database import connect_db, close_db
 
 
11
 
12
  from routers import auth, resume, profile, interview, reports, admin, speech
13
 
@@ -19,6 +22,17 @@ async def lifespan(app: FastAPI):
19
  # Startup
20
  await connect_db()
21
  os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
22
  print(f"🚀 Interview Bot API running in {settings.APP_ENV} mode")
23
  yield
24
  # Shutdown
 
1
  from contextlib import asynccontextmanager
2
+ import asyncio
3
  from fastapi import FastAPI
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from fastapi.staticfiles import StaticFiles
 
9
 
10
  from config import get_settings
11
  from database import connect_db, close_db
12
+ from services.tts_service import warmup_xtts_model
13
+ from services.stt_service import warmup_whisper_model
14
 
15
  from routers import auth, resume, profile, interview, reports, admin, speech
16
 
 
22
  # Startup
23
  await connect_db()
24
  os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
25
+ try:
26
+ await asyncio.wait_for(warmup_xtts_model(), timeout=45)
27
+ print("XTTS warmup: ready")
28
+ except Exception as exc:
29
+ print(f"XTTS warmup skipped: {exc}")
30
+
31
+ try:
32
+ await asyncio.wait_for(warmup_whisper_model(), timeout=45)
33
+ print("Whisper warmup: ready")
34
+ except Exception as exc:
35
+ print(f"Whisper warmup skipped: {exc}")
36
  print(f"🚀 Interview Bot API running in {settings.APP_ENV} mode")
37
  yield
38
  # Shutdown
backend/requirements.txt CHANGED
@@ -16,3 +16,4 @@ python-dotenv==1.0.1
16
  aiofiles==24.1.0
17
  pypdf==5.4.0
18
  python-docx==1.1.2
 
 
16
  aiofiles==24.1.0
17
  pypdf==5.4.0
18
  python-docx==1.1.2
19
+ faster-whisper==1.0.3
backend/routers/speech.py CHANGED
@@ -1,9 +1,10 @@
1
- from fastapi import APIRouter, Depends, HTTPException
2
  from fastapi.responses import Response
3
  from pydantic import BaseModel
4
 
5
  from auth.jwt import get_current_user
6
- from services.tts_service import synthesize_wav
 
7
 
8
  router = APIRouter()
9
 
@@ -19,6 +20,14 @@ async def speech_health(current_user: dict = Depends(get_current_user)):
19
  return {"status": "ok", "service": "speech"}
20
 
21
 
 
 
 
 
 
 
 
 
22
  @router.post("/synthesize")
23
  async def synthesize_speech(
24
  request: SpeechSynthesisRequest,
@@ -34,3 +43,26 @@ async def synthesize_speech(
34
  raise HTTPException(status_code=503, detail=str(e))
35
  except Exception as e:
36
  raise HTTPException(status_code=500, detail=f"Speech synthesis failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
2
  from fastapi.responses import Response
3
  from pydantic import BaseModel
4
 
5
  from auth.jwt import get_current_user
6
+ from services.tts_service import synthesize_wav, warmup_xtts_model
7
+ from services.stt_service import transcribe_audio_bytes, warmup_whisper_model
8
 
9
  router = APIRouter()
10
 
 
20
  return {"status": "ok", "service": "speech"}
21
 
22
 
23
+ @router.post("/warmup")
24
+ async def speech_warmup(current_user: dict = Depends(get_current_user)):
25
+ """Warm XTTS model so first interview playback does not hit cold-start delay."""
26
+ await warmup_xtts_model()
27
+ await warmup_whisper_model()
28
+ return {"status": "ok", "message": "speech model warmed"}
29
+
30
+
31
  @router.post("/synthesize")
32
  async def synthesize_speech(
33
  request: SpeechSynthesisRequest,
 
43
  raise HTTPException(status_code=503, detail=str(e))
44
  except Exception as e:
45
  raise HTTPException(status_code=500, detail=f"Speech synthesis failed: {str(e)}")
46
+
47
+
48
+ @router.post("/transcribe")
49
+ async def transcribe_speech(
50
+ audio: UploadFile = File(...),
51
+ language: str = Form("en"),
52
+ current_user: dict = Depends(get_current_user),
53
+ ):
54
+ """Transcribe uploaded interview audio using Whisper model."""
55
+ try:
56
+ payload = await audio.read()
57
+ text = await transcribe_audio_bytes(
58
+ audio_bytes=payload,
59
+ filename=audio.filename or "speech.webm",
60
+ language=language,
61
+ )
62
+ return {"text": text}
63
+ except ValueError as e:
64
+ raise HTTPException(status_code=400, detail=str(e))
65
+ except RuntimeError as e:
66
+ raise HTTPException(status_code=503, detail=str(e))
67
+ except Exception as e:
68
+ raise HTTPException(status_code=500, detail=f"Speech transcription failed: {str(e)}")
backend/services/interview_service.py CHANGED
@@ -2,12 +2,13 @@ import json
2
  import asyncio
3
  from bson import ObjectId
4
  from database import get_db, get_redis
5
- from models.collections import SESSIONS, JOB_ROLES, SKILLS, QUESTIONS, TOPICS, TOPIC_QUESTIONS, ROLE_REQUIREMENTS, RESUMES
6
  from utils.helpers import generate_id, utc_now, str_objectid
7
  from utils.skills import normalize_skill_list, find_matching_skills, find_missing_skills, build_interview_focus_skills
8
  from services.interview_graph import run_interview_graph
9
  from utils.gemini import generate_interview_question_batch, analyze_resume_vs_job_description
10
  from services.job_description_service import get_job_description_for_user
 
11
 
12
  MAX_QUESTIONS = 20
13
  SESSION_TTL = 7200 # 2 hours
@@ -43,6 +44,31 @@ def _safe_int(value, default: int = 0) -> int:
43
  return default
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def _normalize_bank_difficulty(value: str) -> str:
47
  difficulty = (value or "medium").strip().lower()
48
  if difficulty not in {"easy", "medium", "hard"}:
@@ -418,6 +444,13 @@ async def _start_topic_interview(user_id: str, topic_id: str) -> dict:
418
  session_id = generate_id()
419
  _LOCAL_SUMMARIES[session_id] = ""
420
 
 
 
 
 
 
 
 
421
  session_doc = {
422
  "session_id": session_id,
423
  "user_id": user_id,
@@ -434,6 +467,7 @@ async def _start_topic_interview(user_id: str, topic_id: str) -> dict:
434
  "metrics_bank_questions": 0,
435
  "metrics_bank_shortfall": 0,
436
  "metrics_generation_batches": 0,
 
437
  "timer_enabled": timer_enabled,
438
  "timer_seconds": timer_seconds,
439
  "started_at": utc_now(),
@@ -459,6 +493,7 @@ async def _start_topic_interview(user_id: str, topic_id: str) -> dict:
459
  "timer_enabled": str(timer_enabled),
460
  "timer_seconds": str(timer_seconds or ""),
461
  "status": "in_progress",
 
462
  "metrics_gemini_calls": 0,
463
  "metrics_gemini_questions": 0,
464
  "metrics_bank_questions": 0,
@@ -492,6 +527,14 @@ async def _start_topic_interview(user_id: str, topic_id: str) -> dict:
492
  await redis.expire(f"session:{session_id}:pending_questions", SESSION_TTL)
493
 
494
  first_q_data = await redis.hgetall(f"session:{session_id}:q:{first_id}")
 
 
 
 
 
 
 
 
495
  return {
496
  "session_id": session_id,
497
  "interview_type": "topic",
@@ -598,6 +641,13 @@ async def start_interview(
598
  db = get_db()
599
  redis = get_redis()
600
 
 
 
 
 
 
 
 
601
  # Get user skills
602
  skills_doc = await db[SKILLS].find_one({"user_id": user_id})
603
  user_skills = skills_doc.get("skills", ["general"]) if skills_doc else ["general"]
@@ -680,6 +730,7 @@ async def start_interview(
680
  "metrics_bank_questions": initial_bank_questions,
681
  "metrics_bank_shortfall": initial_bank_shortfall,
682
  "metrics_generation_batches": 1,
 
683
  "started_at": utc_now(),
684
  }
685
  await db[SESSIONS].insert_one(session_doc)
@@ -702,6 +753,7 @@ async def start_interview(
702
  "current_difficulty": last_difficulty,
703
  "interview_type": "resume",
704
  "status": "in_progress",
 
705
  "metrics_gemini_calls": initial_gemini_calls,
706
  "metrics_gemini_questions": initial_gemini_questions,
707
  "metrics_bank_questions": initial_bank_questions,
@@ -720,6 +772,13 @@ async def start_interview(
720
  await redis.expire(f"session:{session_id}:pending_questions", SESSION_TTL)
721
 
722
  first_q_data = await redis.hgetall(f"session:{session_id}:q:{first_id}")
 
 
 
 
 
 
 
723
 
724
  return {
725
  "session_id": session_id,
@@ -885,6 +944,15 @@ async def submit_answer(session_id: str, question_id: str, answer: str) -> dict:
885
  raise ValueError("Unable to fetch or generate next question")
886
 
887
  q_data = await redis.hgetall(f"session:{session_id}:q:{next_question_id}")
 
 
 
 
 
 
 
 
 
888
  next_difficulty = q_data.get("difficulty", session.get("current_difficulty", "medium"))
889
  new_count = question_count + 1
890
  new_served_count = served_count + 1
 
2
  import asyncio
3
  from bson import ObjectId
4
  from database import get_db, get_redis
5
+ from models.collections import SESSIONS, USERS, JOB_ROLES, SKILLS, QUESTIONS, TOPICS, TOPIC_QUESTIONS, ROLE_REQUIREMENTS, RESUMES
6
  from utils.helpers import generate_id, utc_now, str_objectid
7
  from utils.skills import normalize_skill_list, find_matching_skills, find_missing_skills, build_interview_focus_skills
8
  from services.interview_graph import run_interview_graph
9
  from utils.gemini import generate_interview_question_batch, analyze_resume_vs_job_description
10
  from services.job_description_service import get_job_description_for_user
11
+ from services.tts_service import prefetch_wav
12
 
13
  MAX_QUESTIONS = 20
14
  SESSION_TTL = 7200 # 2 hours
 
44
  return default
45
 
46
 
47
+ def _normalize_voice_gender(value: str | None) -> str:
48
+ return "male" if (value or "").strip().lower() == "male" else "female"
49
+
50
+
51
+ def _consume_prefetch_task_result(task: asyncio.Task) -> None:
52
+ try:
53
+ task.result()
54
+ except Exception:
55
+ # Prefetch is optional; ignore failures to avoid noisy task warnings.
56
+ pass
57
+
58
+
59
+ def _schedule_question_audio_prefetch(questions: list[str], voice_gender: str) -> None:
60
+ for q in questions:
61
+ text = (q or "").strip()
62
+ if not text:
63
+ continue
64
+ try:
65
+ task = asyncio.create_task(prefetch_wav(text, voice_gender))
66
+ task.add_done_callback(_consume_prefetch_task_result)
67
+ except Exception:
68
+ # Best-effort optimization only.
69
+ pass
70
+
71
+
72
  def _normalize_bank_difficulty(value: str) -> str:
73
  difficulty = (value or "medium").strip().lower()
74
  if difficulty not in {"easy", "medium", "hard"}:
 
444
  session_id = generate_id()
445
  _LOCAL_SUMMARIES[session_id] = ""
446
 
447
+ user_doc = None
448
+ try:
449
+ user_doc = await db[USERS].find_one({"_id": ObjectId(user_id)}, {"speech_settings": 1})
450
+ except Exception:
451
+ user_doc = await db[USERS].find_one({"user_id": user_id}, {"speech_settings": 1})
452
+ speech_voice_gender = _normalize_voice_gender(((user_doc or {}).get("speech_settings") or {}).get("voice_gender"))
453
+
454
  session_doc = {
455
  "session_id": session_id,
456
  "user_id": user_id,
 
467
  "metrics_bank_questions": 0,
468
  "metrics_bank_shortfall": 0,
469
  "metrics_generation_batches": 0,
470
+ "speech_voice_gender": speech_voice_gender,
471
  "timer_enabled": timer_enabled,
472
  "timer_seconds": timer_seconds,
473
  "started_at": utc_now(),
 
493
  "timer_enabled": str(timer_enabled),
494
  "timer_seconds": str(timer_seconds or ""),
495
  "status": "in_progress",
496
+ "speech_voice_gender": speech_voice_gender,
497
  "metrics_gemini_calls": 0,
498
  "metrics_gemini_questions": 0,
499
  "metrics_bank_questions": 0,
 
527
  await redis.expire(f"session:{session_id}:pending_questions", SESSION_TTL)
528
 
529
  first_q_data = await redis.hgetall(f"session:{session_id}:q:{first_id}")
530
+ _schedule_question_audio_prefetch(
531
+ [
532
+ first_q_data.get("question", ""),
533
+ *[q.get("question", "") for q in selected[1:3]],
534
+ ],
535
+ speech_voice_gender,
536
+ )
537
+
538
  return {
539
  "session_id": session_id,
540
  "interview_type": "topic",
 
641
  db = get_db()
642
  redis = get_redis()
643
 
644
+ user_doc = None
645
+ try:
646
+ user_doc = await db[USERS].find_one({"_id": ObjectId(user_id)}, {"speech_settings": 1})
647
+ except Exception:
648
+ user_doc = await db[USERS].find_one({"user_id": user_id}, {"speech_settings": 1})
649
+ speech_voice_gender = _normalize_voice_gender(((user_doc or {}).get("speech_settings") or {}).get("voice_gender"))
650
+
651
  # Get user skills
652
  skills_doc = await db[SKILLS].find_one({"user_id": user_id})
653
  user_skills = skills_doc.get("skills", ["general"]) if skills_doc else ["general"]
 
730
  "metrics_bank_questions": initial_bank_questions,
731
  "metrics_bank_shortfall": initial_bank_shortfall,
732
  "metrics_generation_batches": 1,
733
+ "speech_voice_gender": speech_voice_gender,
734
  "started_at": utc_now(),
735
  }
736
  await db[SESSIONS].insert_one(session_doc)
 
753
  "current_difficulty": last_difficulty,
754
  "interview_type": "resume",
755
  "status": "in_progress",
756
+ "speech_voice_gender": speech_voice_gender,
757
  "metrics_gemini_calls": initial_gemini_calls,
758
  "metrics_gemini_questions": initial_gemini_questions,
759
  "metrics_bank_questions": initial_bank_questions,
 
772
  await redis.expire(f"session:{session_id}:pending_questions", SESSION_TTL)
773
 
774
  first_q_data = await redis.hgetall(f"session:{session_id}:q:{first_id}")
775
+ _schedule_question_audio_prefetch(
776
+ [
777
+ first_q_data.get("question", ""),
778
+ *[item.get("question", "") for item in initial_batch[1:4]],
779
+ ],
780
+ speech_voice_gender,
781
+ )
782
 
783
  return {
784
  "session_id": session_id,
 
944
  raise ValueError("Unable to fetch or generate next question")
945
 
946
  q_data = await redis.hgetall(f"session:{session_id}:q:{next_question_id}")
947
+ speech_voice_gender = _normalize_voice_gender(session.get("speech_voice_gender"))
948
+
949
+ # Prefetch the spoken audio for this question and one-ahead question.
950
+ prefetch_texts = [q_data.get("question", "")]
951
+ peek_next_id = await redis.lindex(f"session:{session_id}:pending_questions", 0)
952
+ if peek_next_id:
953
+ peek_q = await redis.hgetall(f"session:{session_id}:q:{peek_next_id}")
954
+ prefetch_texts.append(peek_q.get("question", ""))
955
+ _schedule_question_audio_prefetch(prefetch_texts, speech_voice_gender)
956
  next_difficulty = q_data.get("difficulty", session.get("current_difficulty", "medium"))
957
  new_count = question_count + 1
958
  new_served_count = served_count + 1
backend/services/stt_service.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import tempfile
4
+
5
+ # On Windows, ctranslate2 and torch can load separate OpenMP runtimes.
6
+ # Allowing duplicates avoids process aborts during model initialization.
7
+ os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
8
+
9
+ _WHISPER_MODEL_CACHE = {}
10
+ _WHISPER_MODEL_LOCK = asyncio.Lock()
11
+
12
+
13
+ def _resolve_device() -> str:
14
+ pref = os.getenv("WHISPER_DEVICE", "auto").strip().lower()
15
+ if pref in {"cpu", "cuda"}:
16
+ return pref
17
+
18
+ try:
19
+ import torch
20
+
21
+ return "cuda" if torch.cuda.is_available() else "cpu"
22
+ except Exception:
23
+ return "cpu"
24
+
25
+
26
+ def _resolve_compute_type(device: str) -> str:
27
+ pref = os.getenv("WHISPER_COMPUTE_TYPE", "auto").strip().lower()
28
+ if pref and pref != "auto":
29
+ return pref
30
+ return "float16" if device == "cuda" else "int8"
31
+
32
+
33
+ def _resolve_model_size() -> str:
34
+ return os.getenv("WHISPER_MODEL_SIZE", "base").strip() or "base"
35
+
36
+
37
+ async def _get_whisper_model():
38
+ model_size = _resolve_model_size()
39
+ device = _resolve_device()
40
+ compute_type = _resolve_compute_type(device)
41
+ cache_key = f"{model_size}|{device}|{compute_type}"
42
+
43
+ async with _WHISPER_MODEL_LOCK:
44
+ if cache_key in _WHISPER_MODEL_CACHE:
45
+ return _WHISPER_MODEL_CACHE[cache_key]
46
+
47
+ def _load_model():
48
+ try:
49
+ from faster_whisper import WhisperModel
50
+ except Exception as exc:
51
+ raise RuntimeError(
52
+ "faster-whisper is not installed in the active Python environment"
53
+ ) from exc
54
+
55
+ try:
56
+ return WhisperModel(model_size, device=device, compute_type=compute_type)
57
+ except Exception:
58
+ # Keep service resilient if GPU config mismatches runtime.
59
+ return WhisperModel(model_size, device="cpu", compute_type="int8")
60
+
61
+ model = await asyncio.to_thread(_load_model)
62
+ _WHISPER_MODEL_CACHE[cache_key] = model
63
+ return model
64
+
65
+
66
async def warmup_whisper_model() -> None:
    """Preload the Whisper model so the first transcription avoids a cold start.

    Warmup is best-effort: a failure is logged and otherwise ignored so that
    callers (application startup, the warmup route) are never broken by it.
    Cancellation is not intercepted, so an outer timeout still works.
    """
    import logging

    try:
        await _get_whisper_model()
    except Exception:
        # Do not swallow silently: without a log line a broken model setup is
        # invisible until the first real transcription request fails.
        logging.getLogger(__name__).warning(
            "Whisper warmup failed; model will be loaded on first transcription",
            exc_info=True,
        )
72
+
73
+
74
async def transcribe_audio_bytes(audio_bytes: bytes, filename: str = "speech.webm", language: str = "en") -> str:
    """Transcribe an in-memory audio payload with the cached Whisper model.

    Args:
        audio_bytes: Raw bytes of the uploaded audio file.
        filename: Original file name; only its extension is used, as the
            suffix of the temporary file handed to the model.
        language: Language code passed to the model; blank falls back to "en".

    Returns:
        The non-empty segment texts joined with single spaces.

    Raises:
        ValueError: If ``audio_bytes`` is empty.
    """
    if not audio_bytes:
        raise ValueError("audio file is required")

    model = await _get_whisper_model()
    ext = os.path.splitext(filename or "speech.webm")[1] or ".webm"
    target_language = (language or "en").strip().lower() or "en"

    def _transcribe() -> str:
        # The whole temp-file lifecycle lives in the worker thread: the disk
        # write no longer blocks the event loop, and the file cannot be
        # deleted underneath the model if the awaiting coroutine is cancelled.
        fd, tmp_path = tempfile.mkstemp(suffix=ext)
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(audio_bytes)

            segments, _ = model.transcribe(
                tmp_path,
                language=target_language,
                beam_size=1,
                best_of=1,
                vad_filter=True,
                condition_on_previous_text=False,
                temperature=0.0,
            )
            # Consume the segments before the temp file is removed.
            texts = ((seg.text or "").strip() for seg in segments)
            return " ".join(t for t in texts if t).strip()
        finally:
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                pass

    return await asyncio.to_thread(_transcribe)
backend/services/tts_service.py CHANGED
@@ -2,9 +2,28 @@ import asyncio
2
  import os
3
  import tempfile
4
  from typing import Tuple
 
5
 
6
  _MODEL_CACHE = {}
7
  _MODEL_LOCK = asyncio.Lock()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  def _select_model(voice_gender: str) -> Tuple[str, str | None]:
@@ -29,7 +48,27 @@ async def _get_tts_model(model_name: str):
29
  "Coqui TTS is not installed in the active Python environment"
30
  ) from exc
31
 
32
- # Use CPU by default for compatibility.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  return TTS(model_name=model_name, progress_bar=False, gpu=False)
34
 
35
  model = await asyncio.to_thread(_load_model)
@@ -37,11 +76,78 @@ async def _get_tts_model(model_name: str):
37
  return model
38
 
39
 
40
- async def synthesize_wav(text: str, voice_gender: str = "female") -> bytes:
41
- content = (text or "").strip()
42
- if not content:
43
- raise ValueError("text is required")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
 
 
 
 
 
 
 
 
 
45
  model_name, speaker = _select_model(voice_gender)
46
  tts = await _get_tts_model(model_name)
47
 
@@ -50,7 +156,7 @@ async def synthesize_wav(text: str, voice_gender: str = "female") -> bytes:
50
  try:
51
  def _synthesize():
52
  kwargs = {
53
- "text": content,
54
  "file_path": tmp_path,
55
  }
56
  if speaker:
@@ -63,3 +169,51 @@ async def synthesize_wav(text: str, voice_gender: str = "female") -> bytes:
63
  finally:
64
  if os.path.exists(tmp_path):
65
  os.remove(tmp_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import os
3
  import tempfile
4
  from typing import Tuple
5
+ from collections import OrderedDict
6
 
7
  _MODEL_CACHE = {}
8
  _MODEL_LOCK = asyncio.Lock()
9
+ _AUDIO_CACHE = OrderedDict()
10
+ _AUDIO_CACHE_LOCK = asyncio.Lock()
11
+
12
+ XTTS_MODEL = "tts_models/multilingual/multi-dataset/xtts_v2"
13
+ XTTS_LANGUAGE = "en"
14
+ XTTS_SPEED = 1.2
15
+ MAX_TEXT_LENGTH = 220
16
+ _XTTS_WARM = False
17
+ AUDIO_CACHE_MAX_ITEMS = 300
18
+
19
+ # User-approved stable voices:
20
+ # - Female: index 45 => Alexandra Hisakawa
21
+ # - Male: index 21 => Abrahan Mack
22
+ XTTS_SPEAKER_BY_GENDER = {
23
+ "female": "Alexandra Hisakawa",
24
+ "male": "Abrahan Mack",
25
+ "auto": "Alexandra Hisakawa",
26
+ }
27
 
28
 
29
  def _select_model(voice_gender: str) -> Tuple[str, str | None]:
 
48
  "Coqui TTS is not installed in the active Python environment"
49
  ) from exc
50
 
51
+ gpu_pref = os.getenv("XTTS_USE_GPU", "auto").strip().lower()
52
+ use_gpu = False
53
+ if gpu_pref in {"1", "true", "yes", "on"}:
54
+ use_gpu = True
55
+ elif gpu_pref in {"0", "false", "no", "off"}:
56
+ use_gpu = False
57
+ else:
58
+ try:
59
+ import torch
60
+
61
+ use_gpu = bool(torch.cuda.is_available())
62
+ except Exception:
63
+ use_gpu = False
64
+
65
+ if use_gpu:
66
+ try:
67
+ return TTS(model_name=model_name, progress_bar=False, gpu=True)
68
+ except Exception:
69
+ # Graceful CPU fallback when CUDA runtime is unavailable/mismatched.
70
+ return TTS(model_name=model_name, progress_bar=False, gpu=False)
71
+
72
  return TTS(model_name=model_name, progress_bar=False, gpu=False)
73
 
74
  model = await asyncio.to_thread(_load_model)
 
76
  return model
77
 
78
 
79
+ def _resolve_xtts_speaker(voice_gender: str) -> str:
80
+ gender = (voice_gender or "female").strip().lower()
81
+ if gender not in XTTS_SPEAKER_BY_GENDER:
82
+ gender = "female"
83
+ return XTTS_SPEAKER_BY_GENDER[gender]
84
+
85
+
86
+ def _truncate_text(value: str, max_length: int = MAX_TEXT_LENGTH) -> str:
87
+ content = " ".join((value or "").strip().split())
88
+ if len(content) <= max_length:
89
+ return content
90
+ trimmed = content[:max_length].rstrip()
91
+ # Keep sentence boundaries cleaner when truncating.
92
+ for marker in ("?", "!", "."):
93
+ if marker in trimmed:
94
+ head = trimmed.rsplit(marker, 1)[0].strip()
95
+ if len(head) >= max_length // 2:
96
+ return f"{head}{marker}"
97
+ return trimmed
98
+
99
+
100
async def warmup_xtts_model() -> None:
    """Preload XTTS to avoid a long cold start on the first interview question.

    Best-effort: a load failure is logged and otherwise ignored so startup and
    the warmup route keep working. Once a load has succeeded the module-level
    ``_XTTS_WARM`` flag short-circuits later calls.
    """
    import logging

    global _XTTS_WARM
    if _XTTS_WARM:
        return
    try:
        await _get_tts_model(XTTS_MODEL)
    except Exception:
        # Keep API startup resilient, but leave a trace: a silent failure here
        # only shows up later as slow or failing synthesis.
        logging.getLogger(__name__).warning(
            "XTTS warmup failed; model will be loaded on first synthesis",
            exc_info=True,
        )
    else:
        _XTTS_WARM = True
111
+
112
+
113
+ def _synthesize_xtts_to_file(tts, text: str, speaker: str, file_path: str) -> None:
114
+ kwargs = {
115
+ "text": text,
116
+ "file_path": file_path,
117
+ "speaker": speaker,
118
+ "language": XTTS_LANGUAGE,
119
+ }
120
+ try:
121
+ # Faster delivery for interview prompts.
122
+ tts.tts_to_file(**kwargs, speed=XTTS_SPEED)
123
+ except TypeError:
124
+ # Some model/runtime combinations may not expose speed arg.
125
+ tts.tts_to_file(**kwargs)
126
+
127
+
128
+ def _build_audio_cache_key(text: str, voice_gender: str) -> str:
129
+ return f"{(voice_gender or 'female').strip().lower()}::{text.strip()}"
130
+
131
+
132
+ async def _get_cached_audio(cache_key: str) -> bytes | None:
133
+ async with _AUDIO_CACHE_LOCK:
134
+ value = _AUDIO_CACHE.get(cache_key)
135
+ if value is None:
136
+ return None
137
+ # LRU touch.
138
+ _AUDIO_CACHE.move_to_end(cache_key)
139
+ return value
140
+
141
 
142
+ async def _set_cached_audio(cache_key: str, data: bytes) -> None:
143
+ async with _AUDIO_CACHE_LOCK:
144
+ _AUDIO_CACHE[cache_key] = data
145
+ _AUDIO_CACHE.move_to_end(cache_key)
146
+ while len(_AUDIO_CACHE) > AUDIO_CACHE_MAX_ITEMS:
147
+ _AUDIO_CACHE.popitem(last=False)
148
+
149
+
150
+ async def _synthesize_fallback_wav(text: str, voice_gender: str) -> bytes:
151
  model_name, speaker = _select_model(voice_gender)
152
  tts = await _get_tts_model(model_name)
153
 
 
156
  try:
157
  def _synthesize():
158
  kwargs = {
159
+ "text": text,
160
  "file_path": tmp_path,
161
  }
162
  if speaker:
 
169
  finally:
170
  if os.path.exists(tmp_path):
171
  os.remove(tmp_path)
172
+
173
+
174
+ async def prefetch_wav(text: str, voice_gender: str = "female") -> None:
175
+ """Best-effort speech prefetch to warm audio cache."""
176
+ try:
177
+ await synthesize_wav(text, voice_gender)
178
+ except Exception:
179
+ # Silent prefetch failure; runtime synth may still succeed later.
180
+ pass
181
+
182
+
183
async def synthesize_wav(text: str, voice_gender: str = "female") -> bytes:
    """Return WAV audio for ``text``, served from the audio cache when possible.

    The text is normalized and truncated first. Synthesis is attempted with
    XTTS; if loading or running XTTS fails, the fallback synthesizer is used
    instead. The resulting audio is stored in the cache either way.

    Args:
        text: Prompt to speak.
        voice_gender: "male", "female" or "auto"; anything else is treated as
            "female".

    Returns:
        The synthesized WAV file contents.

    Raises:
        ValueError: If the text is empty after normalization.
        Exception: Whatever the fallback synthesizer raises when it fails too.
    """
    import logging

    content = _truncate_text(text)
    if not content:
        raise ValueError("text is required")

    normalized_gender = (voice_gender or "female").strip().lower()
    if normalized_gender not in {"male", "female", "auto"}:
        normalized_gender = "female"

    cache_key = _build_audio_cache_key(content, normalized_gender)
    cached = await _get_cached_audio(cache_key)
    if cached:
        return cached

    def _render_xtts(tts, speaker: str) -> bytes:
        # Runs in a worker thread and owns the temp file end to end, so the
        # file is never removed while the synthesizer is still writing it.
        fd, tmp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            _synthesize_xtts_to_file(tts, text=content, speaker=speaker, file_path=tmp_path)
            with open(tmp_path, "rb") as f:
                return f.read()
        finally:
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                pass

    try:
        # Model loading sits inside the try so that a load failure also
        # reaches the fallback instead of failing the request outright.
        tts = await _get_tts_model(XTTS_MODEL)
        speaker = _resolve_xtts_speaker(normalized_gender)
        wav = await asyncio.to_thread(_render_xtts, tts, speaker)
    except Exception:
        # Keep speech available even if the XTTS runtime has issues.
        logging.getLogger(__name__).warning(
            "XTTS synthesis failed; using fallback voice model",
            exc_info=True,
        )
        wav = await _synthesize_fallback_wav(content, normalized_gender)

    await _set_cached_audio(cache_key, wav)
    return wav