| """ |
| app.py — FastAPI entrypoint: WebRTC-first + WebSocket fallback |
| |
| FIXES APPLIED: |
| FIX-SESSION (Issue 1): The voice WS handler now reads user_id from the |
| first 'init' JSON message before processing any audio. The browser now |
| generates a fresh USER_ID on every page load, so each reload becomes a |
| brand-new user and gets a fresh LangGraph thread / DB row. |
| |
| Implementation: |
| • user_id is initialised to None inside ws_voice. |
| • The handler reads the first message (waiting at most _INIT_TIMEOUT seconds) before entering the main receive loop. |
| • On 'init' message, user_id is set and init_ack returned. |
| • All subsequent audio/LLM calls use that session user_id. |
| • If no 'init' is received within 3 s, a random fallback is used |
| (prevents hang for non-browser clients). |
| |
| FIX-CHAT-INIT (Issue 1): ws_chat also reads the 'init' message so chat |
| sessions share the same backend thread as voice sessions for the same |
| user. |
| |
| All performance optimisations (parallel TTS, GPU-batched STT, concurrent |
| LLM+TTS) preserved. |
| """ |
|
|
| import asyncio |
| import json |
| import os |
| import re |
| import struct |
| import uuid |
| from contextlib import asynccontextmanager |
| from pathlib import Path |
|
|
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request |
| from fastapi.responses import FileResponse, HTMLResponse, JSONResponse |
| from fastapi.staticfiles import StaticFiles |
| from starlette.websockets import WebSocketState |
|
|
| from core.backend import AIBackend |
| from services.stt import STTProcessor |
| from services.streaming import ParallelTTSStreamer |
| from db_view.dbapi import app as db_api_app |
|
|
| |
# Optional WebRTC transport. If the pipeline module (or something it needs)
# fails to import/initialise, the app runs in WebSocket-only mode.
WEBRTC_AVAILABLE = False
try:
    from services.webrtc_pipeline import WebRTCSession
except (ImportError, RuntimeError) as _e:
    print(f"[APP] WebRTC pipeline unavailable ({_e}). WebSocket fallback only.")
else:
    WEBRTC_AVAILABLE = True
    print("[APP] WebRTC pipeline available ✓")
|
|
| |
| |
| |
| USE_GEMINI = True |
| USE_OLLAMA = False |
| USE_LOCAL_FALLBACK = False |
|
|
| _active = sum([USE_GEMINI, USE_OLLAMA, USE_LOCAL_FALLBACK]) |
| if _active != 1: |
| raise RuntimeError( |
| f"[CONFIG] Exactly one of USE_GEMINI / USE_OLLAMA / USE_LOCAL_FALLBACK " |
| f"must be True. Got {_active}." |
| ) |
|
|
# Shared AI backend: used directly by the chat/voice WebSocket handlers and
# handed to every WebRTCSession created in /rtc/offer.
ai = AIBackend(
    use_fallback=USE_LOCAL_FALLBACK,
    use_gemini=USE_GEMINI,
    use_ollama=USE_OLLAMA,
)

# Live WebRTC sessions keyed by the session_id returned from /rtc/offer.
_rtc_sessions: dict[str, "WebRTCSession"] = {}
|
|
|
|
| |
| |
| |
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan hook.

    Startup: run the AI backend's async setup before any request is served.
    Shutdown: close every registered WebRTC session, empty the registry and
    close the backend's ``conn`` attribute if it has one. Each close is
    attempted independently, so one failing resource does not stop the
    others from being released; failures are logged, not raised.
    """
    await ai.async_setup()
    print("[APP] AI backend ready ✓")
    try:
        yield
    finally:
        # Snapshot first: closing a session must not mutate the dict we iterate.
        sessions = list(_rtc_sessions.values())
        _rtc_sessions.clear()
        for session in sessions:
            try:
                await session.close()
            except Exception as exc:
                print(f"[APP] Error closing RTC session: {exc}")

        conn = getattr(ai, "conn", None)
        if conn:
            try:
                await conn.close()
            except Exception as exc:
                print(f"[APP] Error closing backend connection: {exc}")
|
|
|
|
app = FastAPI(lifespan=lifespan)

BASE_DIR = Path(__file__).resolve().parent
FRONTEND_DIR = BASE_DIR / "frontend"

# Mounts are best-effort: the API must still start if an optional part is
# missing (StaticFiles raises when its directory does not exist). The failure
# is reported instead of being silently swallowed.
try:
    app.mount("/static", StaticFiles(directory=str(FRONTEND_DIR)), name="static")
except Exception as exc:
    print(f"[APP] /static not mounted ({exc})")

try:
    app.mount("/dbapi", db_api_app)
except Exception as exc:
    print(f"[APP] /dbapi not mounted ({exc})")
|
|
|
|
@app.get("/")
async def root():
    """Serve the frontend's index.html, or a small 404 page when it is absent."""
    page = FRONTEND_DIR / "index.html"
    if not page.exists():
        return HTMLResponse("<h2>frontend/index.html not found</h2>", status_code=404)
    return FileResponse(str(page))
|
|
|
|
@app.get("/db-view")
async def db_view():
    """Serve db_view/db.html, or a small 404 page when it is absent."""
    page = BASE_DIR / "db_view" / "db.html"
    if not page.exists():
        return HTMLResponse("<h2>db_view/db.html not found</h2>", status_code=404)
    return FileResponse(str(page))
|
|
|
|
@app.get("/health")
async def health():
    """Report liveness, STT model readiness/error and the live RTC session count."""
    # Imported lazily so the current module-level values are read on each call.
    from services.stt import _model_error, _model_ready

    payload = {
        "status": "ok",
        "model_ready": _model_ready.is_set(),
        "model_error": _model_error,
        "rtc_sessions": len(_rtc_sessions),
    }
    return JSONResponse(payload)
|
|
|
|
| |
| |
| |
|
|
@app.post("/rtc/offer")
async def rtc_offer(request: Request):
    """
    Handle a WebRTC SDP offer and return the SDP answer.

    Request JSON: ``{"sdp": str, "type": str (default "offer"),
    "session_id": optional str}``. If ``session_id`` is unknown or absent, a
    new ``WebRTCSession`` is created. It is registered under the supplied id
    or under a fresh random hex id. A known id re-uses the existing session.
    The response is the dict returned by ``handle_offer`` plus ``session_id``.

    Errors:
      * 503 if the WebRTC pipeline is unavailable.
      * 400 if the body is not a JSON object, ``sdp`` is missing/empty, or
        ``session_id`` is not a string.
      * If ``handle_offer`` raises for a session created by this request, the
        session is unregistered and closed before the error propagates. Failed
        negotiations therefore do not leak sessions.
    """
    if not WEBRTC_AVAILABLE:
        return JSONResponse(
            {"error": "WebRTC unavailable. Use WebSocket fallback at /ws/voice"},
            status_code=503,
        )
    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError are ValueErrors
        body = None
    if not isinstance(body, dict):
        return JSONResponse({"error": "Request body must be a JSON object"}, status_code=400)

    sdp = body.get("sdp", "")
    sdp_type = body.get("type", "offer")
    session_id = body.get("session_id") or uuid.uuid4().hex
    if not isinstance(sdp, str) or not sdp:
        return JSONResponse({"error": "Missing 'sdp'"}, status_code=400)
    if not isinstance(session_id, str):
        return JSONResponse({"error": "'session_id' must be a string"}, status_code=400)

    session = _rtc_sessions.get(session_id)
    created = session is None
    if created:
        session = WebRTCSession(ai_backend=ai)
        _rtc_sessions[session_id] = session
        print(f"[RTC] New session: {session_id} user_id={session.user_id}")

    try:
        answer = await session.handle_offer(sdp, sdp_type)
    except Exception:
        # Only roll back sessions this request created; an existing session
        # may still be usable by its client.
        if created:
            _rtc_sessions.pop(session_id, None)
            try:
                await session.close()
            except Exception as close_exc:
                print(f"[RTC] Error closing failed session {session_id}: {close_exc}")
        raise
    return JSONResponse({**answer, "session_id": session_id})
|
|
|
|
@app.post("/rtc/ice")
async def rtc_ice(request: Request):
    """
    Forward a trickled ICE candidate to an existing WebRTC session.

    Request JSON: ``{"session_id": ..., "candidate": {...}}``.
    Returns the following:
      * 503 if WebRTC is unavailable.
      * 400 if the body is not a JSON object (previously a 500).
      * 404 for an unknown session.
      * ``{"ok": true}`` otherwise.
    """
    if not WEBRTC_AVAILABLE:
        return JSONResponse({"error": "WebRTC unavailable"}, status_code=503)
    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError are ValueErrors
        body = None
    if not isinstance(body, dict):
        return JSONResponse({"error": "Request body must be a JSON object"}, status_code=400)

    session_id = body.get("session_id", "")
    candidate = body.get("candidate", {})
    session = _rtc_sessions.get(session_id)
    if session is None:
        return JSONResponse({"error": "Session not found"}, status_code=404)
    await session.add_ice_candidate(candidate)
    return JSONResponse({"ok": True})
|
|
|
|
@app.delete("/rtc/session/{session_id}")
async def rtc_close(session_id: str):
    """Unregister and close a WebRTC session; unknown ids are not an error."""
    removed = _rtc_sessions.pop(session_id, None)
    if removed:
        await removed.close()
    return JSONResponse({"ok": True})
|
|
|
|
| |
| |
| |
|
|
| _DIGIT_WORDS = { |
| "0": "শূন্য", |
| "1": "এক", |
| "2": "দুই", |
| "3": "তিন", |
| "4": "চার", |
| "5": "পাঁচ", |
| "6": "ছয়", |
| "7": "সাত", |
| "8": "আট", |
| "9": "নয়", |
| "০": "শূন্য", |
| "১": "এক", |
| "২": "দুই", |
| "৩": "তিন", |
| "৪": "চার", |
| "৫": "পাঁচ", |
| "৬": "ছয়", |
| "৭": "সাত", |
| "৮": "আট", |
| "৯": "নয়", |
| "٠": "শূন্য", |
| "١": "এক", |
| "٢": "দুই", |
| "٣": "তিন", |
| "٤": "চার", |
| "٥": "পাঁচ", |
| "٦": "ছয়", |
| "٧": "সাত", |
| "٨": "আট", |
| "٩": "নয়", |
| } |
|
|
|
|
| def _spoken_digits(chunk: str) -> str: |
| digits = [ch for ch in chunk if ch in _DIGIT_WORDS] |
| if len(digits) < 10: |
| return chunk |
| spoken = " ".join(_DIGIT_WORDS[ch] for ch in digits) |
| return spoken |
|
|
|
|
| def _expand_phone_like_numbers(text: str) -> str: |
| if not text: |
| return "" |
|
|
| def repl(match: re.Match[str]) -> str: |
| chunk = match.group(0) |
| spoken = _spoken_digits(chunk) |
| if spoken == chunk: |
| return chunk |
|
|
| prev_char = text[match.start() - 1] if match.start() > 0 else "" |
| next_char = text[match.end()] if match.end() < len(text) else "" |
|
|
| if prev_char and not prev_char.isspace() and prev_char not in "([<{\"'": |
| spoken = " " + spoken |
| if next_char and not next_char.isspace() and next_char not in ")]>.,!?;:}\"'": |
| spoken = spoken + " " |
| return spoken |
|
|
| return re.sub(r"[+\d০-৯٠-٩][\d০-৯٠-٩\s().\-]{8,}[\d০-৯٠-٩]", repl, text) |
|
|
|
|
| def _normalize_ai_text(text: str) -> str: |
| """ |
| Apply small UX wording normalizations to assistant-visible text. |
| (We still instruct the model via system prompt, but this guarantees output.) |
| """ |
| if not text: |
| return "" |
| out = text |
| out = out.replace("উপলব্ধ", "এভেলেবেল") |
| out = out.replace("জ্বি", "আচ্ছা") |
| out = _expand_phone_like_numbers(out) |
| return out |
|
|
|
|
def _ws_open(ws: WebSocket) -> bool:
    """Return True while the server-side view of the client state is CONNECTED."""
    state = ws.client_state
    return state == WebSocketState.CONNECTED
|
|
|
|
async def _safe_text(ws: WebSocket, payload: dict) -> bool:
    """
    Best-effort send of *payload* as a JSON text frame.

    Returns True on success. Returns False, without raising, if the socket is
    not open or if serialising/sending fails for any reason.
    """
    if _ws_open(ws):
        try:
            await ws.send_text(json.dumps(payload))
        except Exception:
            return False
        return True
    return False
|
|
|
|
async def _safe_bytes(ws: WebSocket, data: bytes) -> bool:
    """
    Best-effort send of *data* as a binary frame.

    Returns True on success. Returns False, without raising, if the socket is
    not open or the send fails.
    """
    if not _ws_open(ws):
        return False
    try:
        await ws.send_bytes(data)
    except Exception:
        return False
    else:
        return True
|
|
|
|
| async def _register_user(user_id: str) -> None: |
| if user_id: |
| await ai.ensure_user_thread(user_id) |
|
|
|
|
| |
| |
| |
|
|
@app.websocket("/ws/chat")
async def ws_chat(ws: WebSocket):
    """
    Text-chat WebSocket endpoint.

    Every incoming text frame must be a JSON object:
      * ``{"type": "init", "user_id": ...}`` binds the session to that user id
        (truncated to 64 chars) and is answered with ``init_ack``.
      * ``{"type": "ping"}`` is answered with ``pong``.
      * Anything else is a query. ``user_query`` is streamed through the LLM
        as ``llm_token`` frames, followed by a ``chat`` frame with the
        normalized full text and then an ``end`` frame.

    Without a prior ``init``, the first query's ``user_id`` field is used,
    defaulting to ``"default_user"``. Malformed frames (invalid JSON,
    non-object JSON) get an ``error`` reply and are skipped. They no longer
    terminate the session.
    """
    await ws.accept()
    print("[CHAT] Client connected ✓")

    user_id: str = ""

    try:
        while True:
            raw = await ws.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await _safe_text(ws, {"type": "error", "text": "Invalid JSON"})
                continue
            # Valid JSON scalars/arrays previously crashed the loop on .get().
            if not isinstance(data, dict):
                await _safe_text(ws, {"type": "error", "text": "Expected a JSON object"})
                continue

            msg_type = data.get("type", "")

            if msg_type == "init":
                # `or ""` so a JSON null is treated as "no id", not "None".
                claimed = str(data.get("user_id") or "").strip()[:64]
                if claimed:
                    user_id = claimed
                    print(f"[CHAT] Session restored for user_id={user_id!r}")
                    await _register_user(user_id)
                await _safe_text(ws, {"type": "init_ack", "user_id": user_id})
                continue

            if msg_type == "ping":
                await _safe_text(ws, {"type": "pong"})
                continue

            # No init yet: fall back to the id carried by the query itself.
            if not user_id:
                user_id = str(data.get("user_id") or "").strip()[:64] or "default_user"
                await _register_user(user_id)

            user_query = str(data.get("user_query") or "").strip()
            if not user_query:
                continue

            print(f"[CHAT] user_id={user_id!r} query={user_query!r}")

            try:
                stream = await ai.main(user_id, user_query)
                full_text = ""
                async for token in stream:
                    if not token:
                        continue
                    token = _normalize_ai_text(token)
                    full_text += token
                    # Stop pulling tokens once the client is gone.
                    if not await _safe_text(ws, {"type": "llm_token", "token": token}):
                        break
                if full_text:
                    await _safe_text(ws, {"type": "chat", "text": _normalize_ai_text(full_text)})
            except Exception as exc:
                import traceback

                traceback.print_exc()
                await _safe_text(ws, {"type": "error", "text": str(exc)})

            await _safe_text(ws, {"type": "end"})

    except WebSocketDisconnect:
        print("[CHAT] Client disconnected")
    except Exception as exc:
        if "disconnect" not in str(exc).lower():
            print(f"[CHAT] Error: {exc}")
|
|
|
|
| |
| |
| |
|
|
| |
| _INIT_TIMEOUT = 3.0 |
|
|
|
|
@app.websocket("/ws/voice")
async def ws_voice(ws: WebSocket):
    """
    Voice WebSocket endpoint (the fallback transport to WebRTC).

    Handshake: the first message names the session user if it arrives within
    ``_INIT_TIMEOUT`` seconds and is ``{"type": "init", "user_id": ...}``.
    Otherwise a random ``voice_<hex>`` id is generated. Either way an
    ``init_ack`` is sent. A first message that is *not* an init (an early
    audio frame, a ``brain_mode`` toggle, ...) is processed like any later
    message instead of being dropped.

    Main loop:
      * Binary frames are complete utterances, queued for STT -> LLM -> TTS.
      * Text frames are JSON objects whose ``type`` is one of ``init``,
        ``brain_mode``/``mode``, ``ping``, ``cancel`` or ``speak``. Anything
        else, including malformed or non-object JSON, is ignored.

    Turns run one at a time on a worker task. A new utterance or ``speak``
    while a turn is running cancels that turn and any queued work. Outgoing
    audio chunks are framed as ``struct.pack(">II", turn, seq) + chunk``.
    """
    await ws.accept()
    print("[VOICE] Client connected")

    def _parse_json_object(text):
        """Decode *text* as JSON; return it only if it is an object, else None."""
        if not text:
            return None
        try:
            decoded = json.loads(text)
        except (json.JSONDecodeError, TypeError):
            return None
        return decoded if isinstance(decoded, dict) else None

    def _claimed_user_id(msg: dict) -> str:
        """Normalize an init message's user_id (JSON null -> empty, max 64 chars)."""
        return str(msg.get("user_id") or "").strip()[:64]

    # ---- Handshake: read the first message, init or not ----
    user_id: str = ""
    first_to_replay = None  # non-init first message, handled by the main loop
    try:
        first_raw = await asyncio.wait_for(ws.receive(), timeout=_INIT_TIMEOUT)
    except asyncio.TimeoutError:
        print("[VOICE] No init message within timeout — using fallback user_id")
    else:
        if first_raw.get("type") == "websocket.disconnect":
            print("[VOICE] Client disconnected before init")
            return
        first_msg = _parse_json_object(first_raw.get("text"))
        if first_msg is not None and first_msg.get("type") == "init":
            user_id = _claimed_user_id(first_msg)
        else:
            first_to_replay = first_raw

    if not user_id:
        user_id = f"voice_{uuid.uuid4().hex[:12]}"
        print(f"[VOICE] Fallback user_id={user_id}")
    else:
        print(f"[VOICE] Session user_id={user_id}")

    await _register_user(user_id)
    await _safe_text(ws, {"type": "init_ack", "user_id": user_id})

    # ---- Per-connection state ----
    stt = STTProcessor()
    _active_streamer: ParallelTTSStreamer | None = None
    _active_task: asyncio.Task | None = None
    # Work items: utterance bytes, {"type": "speak", "text": ...} dicts, or the
    # None sentinel that stops the worker.
    _utterance_q: asyncio.Queue[object | None] = asyncio.Queue()
    _worker_task: asyncio.Task | None = None
    _turn_id: int = 0  # incremented per turn; stamped on outgoing frames
    brain_mode_enabled = False

    async def _cancel_active():
        """Cancel the in-flight TTS stream and turn task (if any) and wait for it."""
        nonlocal _active_streamer, _active_task
        if _active_streamer is not None:
            await _active_streamer.cancel()
            _active_streamer = None
        if _active_task is not None and not _active_task.done():
            _active_task.cancel()
            try:
                await _active_task
            except (asyncio.CancelledError, Exception):
                pass
        _active_task = None

    async def _drain_utterance_queue():
        """Discard every queued-but-not-started work item."""
        while True:
            try:
                _utterance_q.get_nowait()
            except asyncio.QueueEmpty:
                break

    async def _handle_speak(text: str):
        """
        Speak *text* directly, without STT or LLM.

        Uses the same framed-audio protocol as normal turns and emits
        ``llm_full`` so the UI can display the spoken text.
        """
        nonlocal _active_streamer, _turn_id

        speak_text = _normalize_ai_text((text or "").strip())
        if not speak_text:
            await _safe_text(ws, {"type": "end"})
            return

        _turn_id += 1
        my_turn = _turn_id

        tts_streamer = ParallelTTSStreamer()
        _active_streamer = tts_streamer
        audio_seq = 0

        async def run_text():
            try:
                await _safe_text(ws, {"type": "llm_full", "text": speak_text, "turn": my_turn})
                await tts_streamer.add_token(speak_text)
            finally:
                await tts_streamer.flush()

        async def run_tts_framed():
            nonlocal audio_seq
            async for chunk in tts_streamer.stream_audio():
                framed = struct.pack(">II", my_turn, audio_seq) + chunk
                if not await _safe_bytes(ws, framed):
                    break
                audio_seq += 1

        await asyncio.gather(run_text(), run_tts_framed(), return_exceptions=True)
        _active_streamer = None
        await _safe_text(ws, {"type": "end"})

    async def _handle_utterance(audio_bytes: bytes):
        """Run one voice turn: STT -> streamed LLM reply -> framed TTS audio."""
        nonlocal _active_streamer, _turn_id

        transcript = await stt.transcribe(audio_bytes)
        if not transcript:
            await _safe_text(ws, {"type": "error", "text": "কথা বুঝতে পারিনি, আবার বলুন।"})
            await _safe_text(ws, {"type": "end"})
            return

        print(f"[VOICE] [{user_id}] STT: {transcript}")
        _turn_id += 1
        my_turn = _turn_id

        if not await _safe_text(ws, {"type": "stt", "text": transcript, "turn": my_turn}):
            return

        if brain_mode_enabled:
            # LLM producer and TTS consumer run concurrently. TTS receives the
            # reply once the LLM stream ends (or is cancelled/fails).
            tts_streamer = ParallelTTSStreamer()
            _active_streamer = tts_streamer
            audio_seq = 0

            async def run_llm():
                full_text = ""
                try:
                    stream = await ai.main(user_id, transcript)
                    async for token in stream:
                        if not token:
                            continue
                        token = _normalize_ai_text(token)
                        full_text += token
                        if not await _safe_text(ws, {"type": "llm_token", "token": token, "turn": my_turn}):
                            break
                except asyncio.CancelledError:
                    raise
                except Exception as exc:
                    print(f"[VOICE] LLM error: {exc}")
                finally:
                    if full_text:
                        # Normalize the whole reply once and speak *that*.
                        # Phone numbers split across tokens are only caught here.
                        spoken = _normalize_ai_text(full_text)
                        await _safe_text(ws, {"type": "llm_full", "text": spoken, "turn": my_turn})
                        await tts_streamer.add_token(spoken)
                    await tts_streamer.flush()

            async def run_tts_framed():
                nonlocal audio_seq
                async for chunk in tts_streamer.stream_audio():
                    framed = struct.pack(">II", my_turn, audio_seq) + chunk
                    if not await _safe_bytes(ws, framed):
                        break
                    audio_seq += 1

            await asyncio.gather(run_llm(), run_tts_framed(), return_exceptions=True)
            _active_streamer = None
        else:
            audio_seq = 0

            async def run_llm():
                full_text = ""
                try:
                    stream = await ai.main(user_id, transcript)
                    async for token in stream:
                        if not token:
                            continue
                        token = _normalize_ai_text(token)
                        full_text += token
                        if not await _safe_text(ws, {"type": "llm_token", "token": token, "turn": my_turn}):
                            break
                except asyncio.CancelledError:
                    raise
                except Exception as exc:
                    print(f"[VOICE] LLM error: {exc}")
                return full_text

            full_text = await run_llm()
            if full_text:
                # Same text for display and speech: the whole-reply normalization.
                spoken = _normalize_ai_text(full_text)
                await _safe_text(ws, {"type": "llm_full", "text": spoken, "turn": my_turn})
                tts_streamer = ParallelTTSStreamer()
                _active_streamer = tts_streamer
                await tts_streamer.add_token(spoken)
                await tts_streamer.flush()
                async for chunk in tts_streamer.stream_audio():
                    framed = struct.pack(">II", my_turn, audio_seq) + chunk
                    if not await _safe_bytes(ws, framed):
                        break
                    audio_seq += 1
                _active_streamer = None

        await _safe_text(ws, {"type": "end"})

    async def _utterance_worker():
        """Process queued work items one at a time until the None sentinel."""
        nonlocal _active_task
        while True:
            item = await _utterance_q.get()
            if item is None:
                break
            try:
                # Each turn gets its own task so barge-in can cancel the turn
                # without killing the worker.
                if isinstance(item, (bytes, bytearray)):
                    _active_task = asyncio.create_task(_handle_utterance(bytes(item)))
                elif isinstance(item, dict) and item.get("type") == "speak":
                    _active_task = asyncio.create_task(_handle_speak(str(item.get("text", ""))))
                else:
                    continue
                await _active_task
            except asyncio.CancelledError:
                # Turn cancelled (barge-in / cancel / shutdown). Keep serving;
                # shutdown is signalled by the None sentinel.
                pass
            except Exception as exc:
                print(f"[VOICE] Utterance worker error: {exc}")
                await _safe_text(ws, {"type": "error", "text": str(exc)})
                await _safe_text(ws, {"type": "end"})
            finally:
                _active_task = None

    async def _handle_message(data: dict) -> None:
        """Dispatch one ASGI ``receive()`` message (binary audio or JSON control)."""
        nonlocal user_id, brain_mode_enabled

        audio_bytes = data.get("bytes")
        if audio_bytes:
            print(f"[VOICE] [{user_id}] Utterance: {len(audio_bytes):,} bytes")
            # Barge-in: a new utterance supersedes the running turn and any
            # work still waiting in the queue.
            if _active_task is not None and not _active_task.done():
                await _cancel_active()
                await _drain_utterance_queue()
            await _utterance_q.put(audio_bytes)
            return

        msg = _parse_json_object(data.get("text"))
        if msg is None:
            return  # malformed or non-object JSON is ignored

        msg_type = msg.get("type", "")
        if msg_type == "init":
            claimed = _claimed_user_id(msg)
            if claimed:
                user_id = claimed
                await _register_user(user_id)
            await _safe_text(ws, {"type": "init_ack", "user_id": user_id})
        elif msg_type in ("brain_mode", "mode"):
            brain_mode_enabled = bool(msg.get("enabled", False))
        elif msg_type == "ping":
            await _safe_text(ws, {"type": "pong"})
        elif msg_type == "cancel":
            await _cancel_active()
            await _drain_utterance_queue()
            await _safe_text(ws, {"type": "end"})
        elif msg_type == "speak":
            # `or ""` so a JSON null is not spoken as "None".
            speak_text = str(msg.get("text") or "").strip()
            if speak_text:
                if _active_task is not None and not _active_task.done():
                    await _cancel_active()
                    await _drain_utterance_queue()
                await _utterance_q.put({"type": "speak", "text": speak_text})

    try:
        _worker_task = asyncio.create_task(_utterance_worker())
        if first_to_replay is not None:
            await _handle_message(first_to_replay)
        while _ws_open(ws):
            try:
                data = await ws.receive()
            except WebSocketDisconnect:
                break
            except Exception as exc:
                if "disconnect" not in str(exc).lower():
                    print(f"[VOICE] Receive error: {exc}")
                break
            await _handle_message(data)

    except WebSocketDisconnect:
        pass
    except Exception as exc:
        if "disconnect" not in str(exc).lower():
            print(f"[VOICE] Error: {exc}")
    finally:
        # Drop queued-but-unstarted turns first so the sentinel is the next
        # item the worker sees; nothing stale is processed after disconnect.
        await _drain_utterance_queue()
        await _utterance_q.put(None)
        if _worker_task is not None and not _worker_task.done():
            _worker_task.cancel()
            try:
                await _worker_task
            except (asyncio.CancelledError, Exception):
                pass
        await _cancel_active()
        print(f"[VOICE] [{user_id}] Handler exiting cleanly.")
|
|