File size: 16,599 Bytes
75ee53d
f84481c
 
5dabf9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75ee53d
 
ed5b8b8
 
 
75ee53d
f2ea5fc
ed5b8b8
f84481c
 
f2ea5fc
ed5b8b8
 
 
 
 
f2ea5fc
5dabf9d
f84481c
 
 
 
 
 
 
 
75ee53d
f84481c
75ee53d
5dabf9d
 
75ee53d
 
 
 
 
 
f84481c
75ee53d
 
 
 
 
 
 
f2ea5fc
f84481c
 
 
 
 
 
f2ea5fc
 
 
ed5b8b8
f84481c
f2ea5fc
f84481c
 
 
5dabf9d
 
 
 
 
 
ed5b8b8
f2ea5fc
 
 
ed5b8b8
 
 
 
 
 
 
 
 
 
 
 
 
5dabf9d
 
 
 
 
 
 
 
 
 
 
f84481c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dabf9d
f84481c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75ee53d
ed5b8b8
 
f2ea5fc
 
ed5b8b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f84481c
 
 
75ee53d
f2ea5fc
ed5b8b8
 
f84481c
5dabf9d
 
 
 
f2ea5fc
 
ed5b8b8
 
 
 
 
 
 
5dabf9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed5b8b8
 
 
 
f84481c
 
ed5b8b8
 
 
f84481c
 
b70a952
75ee53d
b70a952
ed5b8b8
 
 
f2ea5fc
ed5b8b8
b70a952
 
f84481c
ed5b8b8
f2ea5fc
f84481c
5dabf9d
f84481c
75ee53d
5dabf9d
 
 
 
f2ea5fc
ed5b8b8
 
5dabf9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed5b8b8
5dabf9d
75ee53d
f84481c
4d2289b
75ee53d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dabf9d
75ee53d
 
f84481c
75ee53d
 
 
 
 
 
 
f84481c
75ee53d
 
 
f84481c
75ee53d
 
 
 
 
 
 
 
 
 
 
f84481c
75ee53d
 
 
 
 
 
 
 
f84481c
75ee53d
 
ed5b8b8
f2ea5fc
 
ed5b8b8
 
 
 
 
 
 
b70a952
 
f84481c
 
ed5b8b8
 
 
 
75ee53d
 
f84481c
4d2289b
ed5b8b8
 
 
5dabf9d
 
 
 
 
 
75ee53d
5dabf9d
75ee53d
5dabf9d
75ee53d
b70a952
ed5b8b8
 
 
 
f84481c
b70a952
 
f84481c
ed5b8b8
75ee53d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
"""
app.py — FastAPI entrypoint: WebRTC-first + WebSocket fallback

FIXES APPLIED:
  FIX-SESSION (Issue 1): The voice WS handler now reads user_id from the
    first 'init' JSON message before processing any audio. The variable is
    no longer a random UUID per connection — it is the stable USER_ID sent
    by the browser from localStorage. This means every reconnect, even after
    a page reload, hits the same LangGraph thread and restores conversation
    history.

    Implementation:
      • user_id is initialised to an empty string inside ws_voice.
      • The handler first waits for the opening message (up to
        _INIT_TIMEOUT seconds) before entering the audio loop.
      • On an 'init' message, user_id is set; an init_ack carrying the
        resolved user_id is sent in every case.
      • All subsequent audio/LLM calls use that stable user_id.
      • If no valid 'init' arrives as the first message within 3 s, a random
        fallback is used (prevents hang for non-browser clients).

  FIX-CHAT-INIT (Issue 1): ws_chat also reads the 'init' message so chat
    sessions share the same backend thread as voice sessions for the same
    user.

All performance optimisations (parallel TTS, GPU-batched STT, concurrent
LLM+TTS) preserved.
"""

import asyncio
import json
import os
import uuid
from contextlib import asynccontextmanager

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from starlette.websockets import WebSocketState

from core.backend import AIBackend
from services.stt import STTProcessor
from services.streaming import ParallelTTSStreamer

# ── WebRTC (optional) ─────────────────────────────────────────────────────────
# If the WebRTC pipeline cannot be imported (ImportError or RuntimeError), the
# app keeps running with the WebSocket transport only; WEBRTC_AVAILABLE records
# which mode is active for the /rtc/* endpoints.
try:
    from services.webrtc_pipeline import WebRTCSession
except (ImportError, RuntimeError) as _e:
    WEBRTC_AVAILABLE = False
    print(f"[APP] WebRTC pipeline unavailable ({_e}). WebSocket fallback only.")
else:
    WEBRTC_AVAILABLE = True
    print("[APP] WebRTC pipeline available βœ“")

# ══════════════════════════════════════════════════════════════════════════════
#  MODEL ROUTING CONFIG — set exactly ONE to True
# ══════════════════════════════════════════════════════════════════════════════
USE_GEMINI         = False
USE_OLLAMA         = True
USE_LOCAL_FALLBACK = False

# Fail fast at import time if zero or several providers are enabled.
_active = sum(flag for flag in (USE_GEMINI, USE_OLLAMA, USE_LOCAL_FALLBACK))
if _active != 1:
    raise RuntimeError(
        "[CONFIG] Exactly one of USE_GEMINI / USE_OLLAMA / USE_LOCAL_FALLBACK "
        f"must be True. Got {_active}."
    )

# The single AIBackend instance shared by the lifespan hook, the WebRTC
# sessions and both WebSocket handlers; provider chosen by the routing flags.
ai = AIBackend(use_gemini=USE_GEMINI, use_ollama=USE_OLLAMA, use_fallback=USE_LOCAL_FALLBACK)

# Live WebRTC sessions keyed by session_id: filled by POST /rtc/offer,
# removed by DELETE /rtc/session/{session_id} and cleared at shutdown.
_rtc_sessions: dict[str, "WebRTCSession"] = dict()


# ══════════════════════════════════════════════════════════════════════════════
#  LIFESPAN
# ══════════════════════════════════════════════════════════════════════════════

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise the AI backend on startup and release resources on shutdown.

    Startup awaits ``ai.async_setup()``. Shutdown closes every registered
    WebRTC session and, when the backend exposes a truthy ``conn`` attribute,
    awaits ``conn.close()``.

    Cleanup runs in a ``finally`` so it also happens if the app exits with an
    exception. Each ``close()`` is guarded individually, so one failure
    neither skips the remaining cleanup nor goes unreported.
    """
    await ai.async_setup()
    print("[APP] AI backend ready ✓")
    try:
        yield
    finally:
        # Snapshot and empty the registry first, then close sessions one by one.
        sessions = list(_rtc_sessions.values())
        _rtc_sessions.clear()
        for session in sessions:
            try:
                await session.close()
            except Exception as exc:
                print(f"[APP] Error closing RTC session: {exc}")

        conn = getattr(ai, "conn", None)
        if conn:
            try:
                await conn.close()
            except Exception as exc:
                print(f"[APP] Error closing backend connection: {exc}")


app = FastAPI(lifespan=lifespan)

# Serve files from the current working directory under /static/<name>.
# NOTE(security): directory="." publishes EVERY file in the working directory,
# including this source file and anything else stored next to it. Consider
# moving the front-end assets into a dedicated folder and mounting only that.
try:
    app.mount("/static", StaticFiles(directory="."), name="static")
except Exception as exc:
    # Best-effort: the app still starts without /static, but the reason the
    # mount failed is now reported instead of silently discarded.
    print(f"[APP] Static file mount failed: {exc}")


@app.get("/")
async def root():
    """Serve ./index.html, or a small 404 HTML page when it does not exist."""
    page = "index.html"
    if not os.path.exists(page):
        return HTMLResponse("<h2>index.html not found</h2>", status_code=404)
    return FileResponse(page)


@app.get("/health")
async def health():
    """Report service status, STT model readiness/error, and live RTC sessions."""
    # Imported per request so the current module-level values are read.
    from services.stt import _model_ready, _model_error

    payload = {
        "status": "ok",
        "model_ready": _model_ready.is_set(),
        "model_error": _model_error,
        "rtc_sessions": len(_rtc_sessions),
    }
    return JSONResponse(payload)


# ══════════════════════════════════════════════════════════════════════════════
#  WEBRTC SIGNALING ENDPOINTS
# ══════════════════════════════════════════════════════════════════════════════

@app.post("/rtc/offer")
async def rtc_offer(request: Request):
    """Handle a WebRTC SDP offer and return the SDP answer.

    Request JSON: ``{"sdp": ..., "type": "offer", "session_id": ...}``.
    ``type`` defaults to ``"offer"``. A missing or empty ``session_id`` gets
    a fresh hex id. An unknown id creates and registers a new
    ``WebRTCSession``; a known id re-uses the existing session.

    Responses:
      * 200 — ``session.handle_offer(...)`` result merged with ``session_id``.
      * 400 — the body is not a JSON object, ``sdp`` is missing/empty, or
        ``session_id`` cannot be used as a registry key.
      * 503 — the WebRTC pipeline failed to import.

    If ``handle_offer`` raises, a session created by *this* request is
    unregistered and closed before the exception propagates, so failed
    offers no longer leak sessions.
    """
    if not WEBRTC_AVAILABLE:
        return JSONResponse(
            {"error": "WebRTC unavailable. Use WebSocket fallback at /ws/voice"},
            status_code=503,
        )

    try:
        body = await request.json()
    except ValueError:  # malformed JSON or undecodable bytes
        body = None
    if not isinstance(body, dict):
        return JSONResponse(
            {"error": "Request body must be a JSON object"}, status_code=400
        )

    sdp        = body.get("sdp", "")
    sdp_type   = body.get("type", "offer")
    session_id = body.get("session_id") or uuid.uuid4().hex
    if not sdp:
        return JSONResponse({"error": "Missing 'sdp'"}, status_code=400)

    try:
        session = _rtc_sessions.get(session_id)
    except TypeError:  # unhashable id, e.g. a JSON array/object
        return JSONResponse({"error": "Invalid 'session_id'"}, status_code=400)

    created = session is None
    if created:
        session = WebRTCSession(ai_backend=ai)
        _rtc_sessions[session_id] = session
        print(f"[RTC] New session: {session_id} user_id={session.user_id}")

    try:
        answer = await session.handle_offer(sdp, sdp_type)
    except Exception:
        # Only undo what this request created, and only if it is still
        # registered (a concurrent DELETE would already have closed it).
        if created and _rtc_sessions.get(session_id) is session:
            del _rtc_sessions[session_id]
            try:
                await session.close()
            except Exception as exc:
                print(f"[RTC] Error closing failed session {session_id}: {exc}")
        raise
    return JSONResponse({**answer, "session_id": session_id})


@app.post("/rtc/ice")
async def rtc_ice(request: Request):
    """Forward a remote ICE candidate to an existing WebRTC session.

    Request JSON: ``{"session_id": ..., "candidate": {...}}``.
    ``candidate`` defaults to ``{}``.

    Responses:
      * 200 ``{"ok": true}`` — candidate passed to ``session.add_ice_candidate``.
      * 400 — the body is not a JSON object (previously a 500).
      * 404 — no session is registered under ``session_id``.
      * 503 — the WebRTC pipeline failed to import.
    """
    if not WEBRTC_AVAILABLE:
        return JSONResponse({"error": "WebRTC unavailable"}, status_code=503)

    try:
        body = await request.json()
    except ValueError:  # malformed JSON or undecodable bytes
        body = None
    if not isinstance(body, dict):
        return JSONResponse(
            {"error": "Request body must be a JSON object"}, status_code=400
        )

    session_id = body.get("session_id", "")
    candidate  = body.get("candidate", {})
    try:
        session = _rtc_sessions.get(session_id)
    except TypeError:  # unhashable id can never match a registered session
        session = None
    if session is None:
        return JSONResponse({"error": "Session not found"}, status_code=404)
    await session.add_ice_candidate(candidate)
    return JSONResponse({"ok": True})


@app.delete("/rtc/session/{session_id}")
async def rtc_close(session_id: str):
    """Unregister and close a WebRTC session; unknown ids still answer ok."""
    removed = _rtc_sessions.pop(session_id, None)
    if removed:
        await removed.close()
    return JSONResponse({"ok": True})


# ══════════════════════════════════════════════════════════════════════════════
#  WEBSOCKET HELPERS
# ══════════════════════════════════════════════════════════════════════════════

def _ws_open(ws: WebSocket) -> bool:
    """True while Starlette still reports the client side as CONNECTED."""
    state = ws.client_state
    return state == WebSocketState.CONNECTED


async def _safe_text(ws: WebSocket, payload: dict) -> bool:
    """Send *payload* as a JSON text frame without ever raising.

    Returns True when the frame was sent. Returns False when the socket is
    not connected, or when serialisation or the send raised any exception.
    """
    if _ws_open(ws):
        try:
            await ws.send_text(json.dumps(payload))
        except Exception:
            return False
        return True
    return False


async def _safe_bytes(ws: WebSocket, data: bytes) -> bool:
    """Send *data* as a binary frame; report failure instead of raising.

    Returns False if the socket is not connected or the send raised.
    """
    if not _ws_open(ws):
        return False
    try:
        await ws.send_bytes(data)
    except Exception:
        return False
    else:
        return True


# ══════════════════════════════════════════════════════════════════════════════
#  WEBSOCKET β€” CHAT (text only, streaming tokens)
# ══════════════════════════════════════════════════════════════════════════════

@app.websocket("/ws/chat")
async def ws_chat(ws: WebSocket):
    """Text chat over WebSocket, streaming LLM tokens back to the client.

    Client -> server (JSON text frames):
      * ``{"type": "init", "user_id": ...}`` binds the connection to a stable
        user_id (stripped, at most 64 chars). ``init_ack`` is sent when the
        id is non-empty.
      * ``{"type": "ping"}`` is answered with ``pong``.
      * Any other object carrying ``user_query`` is sent to
        ``ai.main(user_id, query)``. If no init has set a user_id yet, the
        frame's own ``user_id`` (or ``"default_user"``) is adopted.

    Server -> client: ``llm_token`` frames, then ``end``; an ``error`` frame
    precedes ``end`` when the backend raises.

    Malformed frames get an ``error`` reply and are skipped; they no longer
    end the session. Token streaming stops as soon as the client is gone.
    """
    import traceback

    await ws.accept()
    print("[CHAT] Client connected ✓")

    # FIX-SESSION: Start with no user_id; wait for 'init' to set it.
    user_id: str = ""

    try:
        while True:
            raw = await ws.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await _safe_text(ws, {"type": "error", "text": "Invalid JSON"})
                continue
            if not isinstance(data, dict):
                # A bare string/number/array would make .get() below raise
                # and tear down the whole connection.
                await _safe_text(ws, {"type": "error", "text": "Expected a JSON object"})
                continue

            msg_type = data.get("type", "")
            raw_id = data.get("user_id")
            # A JSON null must not become the literal user_id "None".
            claimed = "" if raw_id is None else str(raw_id).strip()[:64]

            # ── Init handshake ──────────────────────────────────────────────
            if msg_type == "init":
                if claimed:
                    user_id = claimed
                    print(f"[CHAT] Session restored for user_id={user_id!r}")
                    await _safe_text(ws, {"type": "init_ack", "user_id": user_id})
                continue

            if msg_type == "ping":
                await _safe_text(ws, {"type": "pong"})
                continue

            # Fall back to user_id in message payload (compatibility);
            # an empty/missing id now falls back to "default_user" too.
            if not user_id:
                user_id = claimed or "default_user"

            raw_query = data.get("user_query")
            user_query = "" if raw_query is None else str(raw_query).strip()
            if not user_query:
                continue

            print(f"[CHAT] user_id={user_id!r} query={user_query!r}")

            try:
                stream = await ai.main(user_id, user_query)
                async for token in stream:
                    if not token:
                        continue
                    if not await _safe_text(ws, {"type": "llm_token", "token": token}):
                        break  # client is gone; stop draining the LLM stream
            except Exception as exc:
                traceback.print_exc()
                await _safe_text(ws, {"type": "error", "text": str(exc)})

            await _safe_text(ws, {"type": "end"})

    except WebSocketDisconnect:
        print("[CHAT] Client disconnected")
    except Exception as exc:
        if "disconnect" not in str(exc).lower():
            print(f"[CHAT] Error: {exc}")


# ══════════════════════════════════════════════════════════════════════════════
#  WEBSOCKET — VOICE (STT→LLM→TTS pipeline over WS)
# ══════════════════════════════════════════════════════════════════════════════

# Seconds ws_voice waits for the opening {'type': 'init', ...} frame before it
# falls back to a generated user_id.
_INIT_TIMEOUT: float = 3.0


@app.websocket("/ws/voice")
async def ws_voice(ws: WebSocket):
    """Voice pipeline over WebSocket: binary audio in; STT -> LLM -> TTS out.

    Handshake:
      * The first frame is awaited for up to ``_INIT_TIMEOUT`` seconds.
      * If it is ``{"type": "init", "user_id": ...}`` with a non-empty id,
        that id (stripped, at most 64 chars) is used for every backend call.
      * Otherwise a random ``voice_<hex>`` id is used.
      * ``init_ack`` is sent in every case.
      * A first frame that is NOT an init (audio, ping, cancel, ...) is
        replayed into the main loop instead of being discarded.

    Main loop:
      * Binary frame: one utterance. Any in-flight work is cancelled first,
        then STT runs. The client receives an ``stt`` frame, streamed
        ``llm_token`` frames and TTS audio bytes, and finally ``end``.
      * ``init``: re-bind user_id, then ``init_ack``.
      * ``ping``: reply ``pong``.
      * ``cancel``: abort in-flight work, then ``end``.

    Non-object JSON frames are ignored. In-flight work is cancelled when the
    handler exits.
    """
    await ws.accept()
    print("[VOICE] Client connected")

    def _json_object(text):
        """Decode *text*; return it only if it is a JSON object, else None."""
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, dict) else None

    def _claimed_user_id(msg):
        """user_id from an init message: stripped, <=64 chars, '' if absent/null."""
        raw_id = msg.get("user_id")
        return "" if raw_id is None else str(raw_id).strip()[:64]

    # ── FIX-SESSION: Resolve stable user_id from browser init message ────────
    # Wait up to _INIT_TIMEOUT seconds for the {'type':'init','user_id':...} msg,
    # which script.js sends as the first message on WS open.
    user_id: str = ""
    pending = None  # opening frame that was not an init; replayed by the main loop
    try:
        first_raw = await asyncio.wait_for(ws.receive(), timeout=_INIT_TIMEOUT)
    except asyncio.TimeoutError:
        print("[VOICE] No init message within timeout — using fallback user_id")
    else:
        first_msg = _json_object(first_raw["text"]) if first_raw.get("text") else None
        if first_msg is not None and first_msg.get("type") == "init":
            user_id = _claimed_user_id(first_msg)
        else:
            # Previously this frame was dropped, losing e.g. the client's first
            # utterance or leaving a ping unanswered.
            pending = first_raw

    if not user_id:
        user_id = f"voice_{uuid.uuid4().hex[:12]}"
        print(f"[VOICE] Fallback user_id={user_id}")
    else:
        print(f"[VOICE] Session user_id={user_id}")

    await _safe_text(ws, {"type": "init_ack", "user_id": user_id})

    stt = STTProcessor()
    _active_streamer: ParallelTTSStreamer | None = None
    _active_task:     asyncio.Task | None        = None

    async def _cancel_active():
        """Cancel the in-flight TTS streamer and utterance task, if any."""
        nonlocal _active_streamer, _active_task
        if _active_streamer is not None:
            await _active_streamer.cancel()
            _active_streamer = None
        if _active_task is not None and not _active_task.done():
            _active_task.cancel()
            try:
                await _active_task
            except (asyncio.CancelledError, Exception):
                pass
            _active_task = None

    async def _handle_utterance(audio_bytes: bytes):
        """Run one utterance through STT, then LLM and TTS concurrently."""
        nonlocal _active_streamer

        # ── STT ───────────────────────────────────────────────────────────────
        try:
            transcript = await stt.transcribe(audio_bytes)
        except Exception as exc:
            # Without this the task died silently: the client never received
            # 'end' and asyncio reported an unretrieved task exception.
            print(f"[VOICE] [{user_id}] STT error: {exc}")
            transcript = ""
        if not transcript:
            await _safe_text(ws, {"type": "error", "text": "কথা বুঝতে পারিনি, আবার বলুন।"})
            await _safe_text(ws, {"type": "end"})
            return

        print(f"[VOICE] [{user_id}] STT: {transcript}")
        if not await _safe_text(ws, {"type": "stt", "text": transcript}):
            return

        # ── LLM + TTS (concurrent) ─────────────────────────────────────────────
        tts_streamer     = ParallelTTSStreamer()
        _active_streamer = tts_streamer

        async def run_llm():
            try:
                stream = await ai.main(user_id, transcript)
                async for token in stream:
                    if not token:
                        continue
                    if not await _safe_text(ws, {"type": "llm_token", "token": token}):
                        break
                    await tts_streamer.add_token(token)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                print(f"[VOICE] LLM error: {exc}")
            finally:
                # Always flush so run_tts can drain and finish.
                await tts_streamer.flush()

        async def run_tts():
            async for chunk in tts_streamer.stream_audio():
                if not await _safe_bytes(ws, chunk):
                    break

        await asyncio.gather(run_llm(), run_tts(), return_exceptions=True)
        _active_streamer = None
        await _safe_text(ws, {"type": "end"})

    try:
        while True:
            if not _ws_open(ws):
                break

            if pending is not None:
                # Replay the opening frame that turned out not to be an init.
                data, pending = pending, None
            else:
                try:
                    data = await ws.receive()
                except WebSocketDisconnect:
                    break
                except Exception as exc:
                    if "disconnect" not in str(exc).lower():
                        print(f"[VOICE] Receive error: {exc}")
                    break

            if data.get("bytes"):
                audio_bytes = data["bytes"]
                print(f"[VOICE] [{user_id}] Utterance: {len(audio_bytes):,} bytes")
                # A new utterance supersedes whatever is still being generated/played.
                await _cancel_active()
                _active_task = asyncio.create_task(_handle_utterance(audio_bytes))

            elif data.get("text"):
                msg = _json_object(data["text"])
                if msg is None:
                    continue  # invalid JSON or not an object: ignore the frame
                t = msg.get("type", "")
                if t == "init":
                    # Late re-init (e.g. after reconnect with same WS obj — rare)
                    claimed = _claimed_user_id(msg)
                    if claimed:
                        user_id = claimed
                    await _safe_text(ws, {"type": "init_ack", "user_id": user_id})
                elif t == "ping":
                    await _safe_text(ws, {"type": "pong"})
                elif t == "cancel":
                    await _cancel_active()
                    await _safe_text(ws, {"type": "end"})

    except WebSocketDisconnect:
        pass
    except Exception as exc:
        if "disconnect" not in str(exc).lower():
            print(f"[VOICE] Error: {exc}")
    finally:
        await _cancel_active()
        print(f"[VOICE] [{user_id}] Handler exiting cleanly.")