| """ |
| app.py — FastAPI entrypoint: WebRTC-first + WebSocket fallback |
| |
| FIXES APPLIED: |
| FIX-SESSION (Issue 1): The voice WS handler now reads user_id from the |
| first 'init' JSON message before processing any audio. The variable is |
| no longer a random UUID per connection — it is the stable USER_ID sent |
| by the browser from localStorage. This means every reconnect, even after |
| a page reload, hits the same LangGraph thread and restores conversation |
| history. |
| |
| Implementation: |
| • user_id is initialised to an empty string inside ws_voice. |
| • The handler waits for any early text messages before processing binary. |
| • On 'init' message, user_id is set and init_ack returned. |
| • All subsequent audio/LLM calls use that stable user_id. |
| • If no 'init' is received within 3 s, a random fallback is used |
| (prevents hang for non-browser clients). |
| |
| FIX-CHAT-INIT (Issue 1): ws_chat also reads the 'init' message so chat |
| sessions share the same backend thread as voice sessions for the same |
| user. |
| |
| All performance optimisations (parallel TTS, GPU-batched STT, concurrent |
| LLM+TTS) preserved. |
| """ |
|
|
| import asyncio |
| import json |
| import os |
| import uuid |
| from contextlib import asynccontextmanager |
|
|
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request |
| from fastapi.responses import FileResponse, HTMLResponse, JSONResponse |
| from fastapi.staticfiles import StaticFiles |
| from starlette.websockets import WebSocketState |
|
|
| from core.backend import AIBackend |
| from services.stt import STTProcessor |
| from services.streaming import ParallelTTSStreamer |
|
|
| |
# Optional WebRTC transport. When the pipeline cannot be imported the app
# still starts; the /rtc/* endpoints then answer 503 and clients are expected
# to use the WebSocket endpoints instead.
try:
    from services.webrtc_pipeline import WebRTCSession
except (ImportError, RuntimeError) as _e:
    WEBRTC_AVAILABLE = False
    print(f"[APP] WebRTC pipeline unavailable ({_e}). WebSocket fallback only.")
else:
    # Only reached when the import succeeded.
    WEBRTC_AVAILABLE = True
    # The trailing glyph was mis-encoded in the source; restored as a check mark.
    print("[APP] WebRTC pipeline available ✓")
|
|
| |
| |
| |
# Backend selector flags: exactly one of the three must be True.
USE_GEMINI = False
USE_OLLAMA = True
USE_LOCAL_FALLBACK = False

# Booleans add up as 0/1, so the sum is the number of enabled backends.
_active = sum((USE_GEMINI, USE_OLLAMA, USE_LOCAL_FALLBACK))
if not _active == 1:
    raise RuntimeError(
        "[CONFIG] Exactly one of USE_GEMINI / USE_OLLAMA / USE_LOCAL_FALLBACK "
        f"must be True. Got {_active}."
    )
|
|
# Single AI backend instance shared by every endpoint in this module; the
# provider is chosen by the USE_* flags.
ai = AIBackend(
    use_fallback=USE_LOCAL_FALLBACK,
    use_ollama=USE_OLLAMA,
    use_gemini=USE_GEMINI,
)

# Live WebRTC sessions keyed by the session id exchanged with the client.
_rtc_sessions: dict[str, "WebRTCSession"] = dict()
|
|
|
|
| |
| |
| |
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: set up the AI backend, then tear everything down.

    Startup awaits ``ai.async_setup()``. Shutdown closes every registered
    WebRTC session and then the backend's ``conn`` attribute, if it has one.
    Each close is attempted independently, so one failing session no longer
    prevents the remaining sessions or the connection from being closed.
    """
    await ai.async_setup()
    # The trailing glyph was mis-encoded in the source; restored as a check mark.
    print("[APP] AI backend ready ✓")
    try:
        yield
    finally:
        # Iterate over a snapshot: close() may touch the registry.
        for session_id, session in list(_rtc_sessions.items()):
            try:
                await session.close()
            except Exception as exc:
                print(f"[APP] Failed to close RTC session {session_id}: {exc}")
        _rtc_sessions.clear()

        conn = getattr(ai, "conn", None)
        if conn:
            try:
                await conn.close()
            except Exception as exc:
                # Best-effort cleanup: report, but never fail shutdown.
                print(f"[APP] Failed to close backend connection: {exc}")
|
|
|
|
app = FastAPI(lifespan=lifespan)

# NOTE(security): directory="." publishes every file in the process working
# directory under /static (source files, configuration, ...). Consider serving
# a dedicated assets directory instead. Left unchanged here because the URLs
# the front-end requests are not visible from this file.
try:
    app.mount("/static", StaticFiles(directory="."), name="static")
except Exception as _mount_exc:
    # Static assets are optional: keep the API up, but say why the mount failed
    # instead of swallowing the error silently.
    print(f"[APP] Static files not mounted: {_mount_exc}")
|
|
|
|
@app.get("/")
async def root():
    """Serve ``index.html`` from the working directory, or a 404 HTML stub."""
    index_page = "index.html"
    if not os.path.exists(index_page):
        return HTMLResponse("<h2>index.html not found</h2>", status_code=404)
    return FileResponse(index_page)
|
|
|
|
@app.get("/health")
async def health():
    """Report STT model status and the number of registered WebRTC sessions.

    The STT status objects are imported on every call, so the response shows
    the values ``services.stt`` holds at request time.
    """
    from services.stt import _model_error, _model_ready

    report = {
        "status": "ok",
        "model_ready": _model_ready.is_set(),
        "model_error": _model_error,
        "rtc_sessions": len(_rtc_sessions),
    }
    return JSONResponse(report)
|
|
|
|
| |
| |
| |
|
|
@app.post("/rtc/offer")
async def rtc_offer(request: Request):
    """Handle an SDP offer and return the session's answer.

    Request body (JSON object): ``sdp`` (required), ``type`` (default
    ``"offer"``) and an optional ``session_id``. An unknown or missing
    ``session_id`` creates a new ``WebRTCSession``; a known one is reused.

    Responses:
        200 - the answer returned by ``handle_offer`` plus ``session_id``.
        400 - body is not a JSON object, or ``sdp`` is missing/empty.
        503 - the WebRTC pipeline is not available.
    """
    if not WEBRTC_AVAILABLE:
        return JSONResponse(
            {"error": "WebRTC unavailable. Use WebSocket fallback at /ws/voice"},
            status_code=503,
        )

    # A malformed body is a client error, not a server crash.
    try:
        body = await request.json()
    except ValueError:
        return JSONResponse({"error": "Request body must be valid JSON"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"error": "Request body must be a JSON object"}, status_code=400)

    sdp = body.get("sdp", "")
    if not sdp:
        return JSONResponse({"error": "Missing 'sdp'"}, status_code=400)
    sdp_type = body.get("type", "offer")
    session_id = body.get("session_id") or uuid.uuid4().hex

    session = _rtc_sessions.get(session_id)
    created = session is None
    if created:
        session = WebRTCSession(ai_backend=ai)
        _rtc_sessions[session_id] = session
        print(f"[RTC] New session: {session_id} user_id={session.user_id}")

    try:
        answer = await session.handle_offer(sdp, sdp_type)
    except Exception:
        # Do not leave a half-initialised session registered when the very
        # first offer fails; existing sessions are left untouched.
        if created:
            _rtc_sessions.pop(session_id, None)
            try:
                await session.close()
            except Exception as close_exc:
                print(f"[RTC] Failed to close session {session_id}: {close_exc}")
        raise
    return JSONResponse({**answer, "session_id": session_id})
|
|
|
|
@app.post("/rtc/ice")
async def rtc_ice(request: Request):
    """Forward a trickled ICE candidate to an existing WebRTC session.

    Request body (JSON object): ``session_id`` and ``candidate``.

    Responses:
        200 - ``{"ok": true}`` after the candidate was handed to the session.
        400 - body is not a JSON object.
        404 - no session is registered under ``session_id``.
        503 - the WebRTC pipeline is not available.
    """
    if not WEBRTC_AVAILABLE:
        return JSONResponse({"error": "WebRTC unavailable"}, status_code=503)

    # A malformed body is a client error, not a server crash.
    try:
        body = await request.json()
    except ValueError:
        return JSONResponse({"error": "Request body must be valid JSON"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"error": "Request body must be a JSON object"}, status_code=400)

    session_id = body.get("session_id", "")
    candidate = body.get("candidate", {})
    session = _rtc_sessions.get(session_id)
    if session is None:
        return JSONResponse({"error": "Session not found"}, status_code=404)
    await session.add_ice_candidate(candidate)
    return JSONResponse({"ok": True})
|
|
|
|
@app.delete("/rtc/session/{session_id}")
async def rtc_close(session_id: str):
    """Unregister the session stored under *session_id* and close it, if any.

    Always answers ``{"ok": true}``, including for unknown ids.
    """
    if closing := _rtc_sessions.pop(session_id, None):
        await closing.close()
    return JSONResponse({"ok": True})
|
|
|
|
| |
| |
| |
|
|
def _ws_open(ws: WebSocket) -> bool:
    """Return True while the client side of *ws* is in the CONNECTED state."""
    client_state = ws.client_state
    return client_state == WebSocketState.CONNECTED
|
|
|
|
async def _safe_text(ws: WebSocket, payload: dict) -> bool:
    """Send *payload* as a JSON text frame without ever raising.

    Returns True when the frame was sent; False when the socket is not open
    or when serialising/sending failed for any reason.
    """
    delivered = False
    if _ws_open(ws):
        try:
            await ws.send_text(json.dumps(payload))
        except Exception:
            # Best-effort send: the caller only needs to know it failed.
            pass
        else:
            delivered = True
    return delivered
|
|
|
|
async def _safe_bytes(ws: WebSocket, data: bytes) -> bool:
    """Send *data* as a binary frame without ever raising.

    Returns True when the frame was sent; False when the socket is not open
    or when sending failed for any reason.
    """
    if _ws_open(ws):
        try:
            await ws.send_bytes(data)
        except Exception:
            # Best-effort send: report failure through the return value.
            return False
        return True
    return False
|
|
|
|
| |
| |
| |
|
|
@app.websocket("/ws/chat")
async def ws_chat(ws: WebSocket):
    """Text chat over WebSocket.

    Every incoming frame must be a JSON object. Recognised ``type`` values:

    * ``init`` - adopt the ``user_id`` sent by the client and reply
      ``init_ack``.
    * ``ping`` - reply ``pong``.
    * anything else - treated as a chat turn: ``user_query`` is passed to
      ``ai.main`` and the reply is streamed back as ``llm_token`` messages
      followed by ``end``.

    If no ``init`` was received, the ``user_id`` of the first chat turn is
    used, falling back to ``"default_user"``. Malformed frames produce an
    ``error`` message (or are ignored) instead of terminating the connection.
    """
    import traceback  # only needed on the error path of this handler

    def _clean_id(value) -> str:
        # JSON null counts as "not provided" rather than the string "None".
        return "" if value is None else str(value).strip()[:64]

    await ws.accept()
    # The trailing glyph was mis-encoded in the source; restored as a check mark.
    print("[CHAT] Client connected ✓")

    user_id: str = ""

    try:
        while True:
            raw = await ws.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await _safe_text(ws, {"type": "error", "text": "Invalid JSON"})
                continue
            if not isinstance(data, dict):
                # Valid JSON but not an object (e.g. a list or a number):
                # reject it instead of crashing on .get() below.
                await _safe_text(ws, {"type": "error", "text": "Expected a JSON object"})
                continue

            msg_type = data.get("type", "")

            if msg_type == "init":
                claimed = _clean_id(data.get("user_id"))
                if claimed:
                    user_id = claimed
                    print(f"[CHAT] Session restored for user_id={user_id!r}")
                await _safe_text(ws, {"type": "init_ack", "user_id": user_id})
                continue

            if msg_type == "ping":
                await _safe_text(ws, {"type": "pong"})
                continue

            # No 'init' seen yet: take the id from the message itself. An
            # empty or missing id falls back to the shared default.
            if not user_id:
                user_id = _clean_id(data.get("user_id")) or "default_user"

            query = data.get("user_query", "")
            # A non-string query is ignored like an empty one.
            user_query = query.strip() if isinstance(query, str) else ""
            if not user_query:
                continue

            print(f"[CHAT] user_id={user_id!r} query={user_query!r}")

            try:
                stream = await ai.main(user_id, user_query)
                async for token in stream:
                    if not token:
                        continue
                    # Stop consuming the model output once the client is gone.
                    if not await _safe_text(ws, {"type": "llm_token", "token": token}):
                        break
            except Exception as exc:
                traceback.print_exc()
                await _safe_text(ws, {"type": "error", "text": str(exc)})

            await _safe_text(ws, {"type": "end"})

    except WebSocketDisconnect:
        print("[CHAT] Client disconnected")
    except Exception as exc:
        if "disconnect" not in str(exc).lower():
            print(f"[CHAT] Error: {exc}")
|
|
|
|
| |
| |
| |
|
|
| |
# Seconds to wait for the client's first frame (normally the 'init' message)
# before falling back to a random user_id.
_INIT_TIMEOUT = 3.0


def _parse_control_frame(text: str) -> "dict | None":
    """Decode a JSON control frame; return None unless it is a JSON object."""
    try:
        msg = json.loads(text)
    except json.JSONDecodeError:
        return None
    return msg if isinstance(msg, dict) else None


def _claimed_user_id(msg: dict) -> str:
    """Return the trimmed, 64-char-capped ``user_id`` of *msg* ('' if absent)."""
    value = msg.get("user_id")
    # JSON null counts as "not provided" rather than the string "None".
    return "" if value is None else str(value).strip()[:64]


@app.websocket("/ws/voice")
async def ws_voice(ws: WebSocket):
    """Voice pipeline over WebSocket: audio in, transcript + tokens + audio out.

    Handshake: the first frame is awaited for up to ``_INIT_TIMEOUT`` seconds.
    If it is an ``init`` message its ``user_id`` is adopted; otherwise a random
    ``voice_<hex>`` id is generated. A first frame that is *not* ``init``
    (e.g. audio from a client that skips the handshake) is kept and processed
    like any later frame instead of being dropped. ``init_ack`` is then sent.

    Afterwards each binary frame is one utterance: any in-flight response is
    cancelled, the audio is transcribed (``stt`` message), and the LLM reply
    is streamed as ``llm_token`` messages while synthesized audio is streamed
    as binary frames; ``end`` closes the turn. Text frames carry control
    messages: ``init``, ``ping`` and ``cancel``.
    """
    # Shown to the user when speech could not be transcribed; Bengali for
    # "Could not understand, please say it again." (The literal was
    # mis-encoded in the source and has been restored.)
    stt_failed_text = "কথা বুঝতে পারিনি, আবার বলুন।"

    await ws.accept()
    print("[VOICE] Client connected")

    user_id: str = ""
    pending_first: dict | None = None  # first frame, when it was not an 'init'
    try:
        first_raw = await asyncio.wait_for(ws.receive(), timeout=_INIT_TIMEOUT)
    except asyncio.TimeoutError:
        print("[VOICE] No init message within timeout — using fallback user_id")
    else:
        first_text = first_raw.get("text")
        first_msg = _parse_control_frame(first_text) if first_text else None
        if first_msg is not None and first_msg.get("type") == "init":
            user_id = _claimed_user_id(first_msg)
        else:
            pending_first = first_raw

    if not user_id:
        user_id = f"voice_{uuid.uuid4().hex[:12]}"
        print(f"[VOICE] Fallback user_id={user_id}")
    else:
        print(f"[VOICE] Session user_id={user_id}")

    await _safe_text(ws, {"type": "init_ack", "user_id": user_id})

    stt = STTProcessor()
    _active_streamer: ParallelTTSStreamer | None = None
    _active_task: asyncio.Task | None = None

    async def _cancel_active():
        """Stop the in-flight response (TTS streamer and utterance task)."""
        nonlocal _active_streamer, _active_task
        if _active_streamer is not None:
            await _active_streamer.cancel()
            _active_streamer = None
        if _active_task is not None and not _active_task.done():
            _active_task.cancel()
            try:
                await _active_task
            except (asyncio.CancelledError, Exception):
                pass
        _active_task = None

    async def _handle_utterance(audio_bytes: bytes):
        """Run one turn: STT, then LLM streaming and TTS streaming together."""
        nonlocal _active_streamer

        try:
            transcript = await stt.transcribe(audio_bytes)
        except Exception as exc:
            # This runs inside a background task: an escaping exception would
            # be lost and the client would wait for 'end' forever.
            print(f"[VOICE] STT error: {exc}")
            transcript = None
        if not transcript:
            await _safe_text(ws, {"type": "error", "text": stt_failed_text})
            await _safe_text(ws, {"type": "end"})
            return

        print(f"[VOICE] [{user_id}] STT: {transcript}")
        if not await _safe_text(ws, {"type": "stt", "text": transcript}):
            return

        tts_streamer = ParallelTTSStreamer()
        _active_streamer = tts_streamer

        async def run_llm():
            try:
                stream = await ai.main(user_id, transcript)
                async for token in stream:
                    if not token:
                        continue
                    if not await _safe_text(ws, {"type": "llm_token", "token": token}):
                        break
                    await tts_streamer.add_token(token)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                print(f"[VOICE] LLM error: {exc}")
            finally:
                # Always flush so the audio consumer can finish.
                await tts_streamer.flush()

        async def run_tts():
            async for chunk in tts_streamer.stream_audio():
                if not await _safe_bytes(ws, chunk):
                    break

        await asyncio.gather(run_llm(), run_tts(), return_exceptions=True)
        _active_streamer = None
        await _safe_text(ws, {"type": "end"})

    async def _dispatch(frame: dict):
        """Route one received frame: audio starts a turn, text is control."""
        nonlocal user_id, _active_task

        audio_bytes = frame.get("bytes")
        if audio_bytes:
            print(f"[VOICE] [{user_id}] Utterance: {len(audio_bytes):,} bytes")
            # A new utterance interrupts whatever is still being spoken.
            await _cancel_active()
            _active_task = asyncio.create_task(_handle_utterance(audio_bytes))
            return

        text = frame.get("text")
        msg = _parse_control_frame(text) if text else None
        if msg is None:
            # Not a JSON object (or neither audio nor text): ignore.
            return

        kind = msg.get("type", "")
        if kind == "init":
            claimed = _claimed_user_id(msg)
            if claimed:
                user_id = claimed
            await _safe_text(ws, {"type": "init_ack", "user_id": user_id})
        elif kind == "ping":
            await _safe_text(ws, {"type": "pong"})
        elif kind == "cancel":
            await _cancel_active()
            await _safe_text(ws, {"type": "end"})

    try:
        if pending_first is not None:
            await _dispatch(pending_first)

        while _ws_open(ws):
            try:
                data = await ws.receive()
            except WebSocketDisconnect:
                break
            except Exception as exc:
                if "disconnect" not in str(exc).lower():
                    print(f"[VOICE] Receive error: {exc}")
                break

            await _dispatch(data)

    except WebSocketDisconnect:
        pass
    except Exception as exc:
        if "disconnect" not in str(exc).lower():
            print(f"[VOICE] Error: {exc}")
    finally:
        await _cancel_active()
        print(f"[VOICE] [{user_id}] Handler exiting cleanly.")
|
|