Commit ·
4d2289b
1
Parent(s): b70a952
fluent communication part :: rakib
Browse files- .env +2 -0
- app.py +8 -27
- core/backend.py +272 -251
- frontend/index.html +1 -0
- frontend/script.js +97 -102
- services/streaming.py +178 -113
.env
CHANGED
|
@@ -5,6 +5,8 @@ LANGCHAIN_ENDPOINT='https://api.smith.langchain.com'
|
|
| 5 |
LANGCHAIN_API_KEY='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'
|
| 6 |
LANGCHAIN_PROJECT='Default'
|
| 7 |
|
|
|
|
|
|
|
| 8 |
# TWILIO_ACCOUNT_SID="ACfafc0d2d007bdf14b21bb3e14a7a7b31"
|
| 9 |
# TWILIO_AUTH_TOKEN="ed15fa98748c8c3d3d02cb54e431a187"
|
| 10 |
# TWILIO_PHONE_NUMBER="+14343375085"
|
|
|
|
| 5 |
LANGCHAIN_API_KEY='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'
|
| 6 |
LANGCHAIN_PROJECT='Default'
|
| 7 |
|
| 8 |
+
GOOGLE_API_KEY="AIzaSyA9sqz4YKQHKXR9TU1imw0DPOghzHOMiBo"
|
| 9 |
+
|
| 10 |
# TWILIO_ACCOUNT_SID="ACfafc0d2d007bdf14b21bb3e14a7a7b31"
|
| 11 |
# TWILIO_AUTH_TOKEN="ed15fa98748c8c3d3d02cb54e431a187"
|
| 12 |
# TWILIO_PHONE_NUMBER="+14343375085"
|
app.py
CHANGED
|
@@ -1,19 +1,3 @@
|
|
| 1 |
-
"""
|
| 2 |
-
app.py — FastAPI entry point
|
| 3 |
-
|
| 4 |
-
Fixes applied
|
| 5 |
-
─────────────
|
| 6 |
-
1. STT is now fully async (stt.transcribe is a coroutine) — no more
|
| 7 |
-
asyncio.to_thread wrapper needed in the WS handler.
|
| 8 |
-
2. BARGE-IN: when the client sends a new audio blob while TTS is still
|
| 9 |
-
playing, the running tts_streamer is cancelled before starting a new
|
| 10 |
-
turn. The client enforces isProcessing so this should be rare, but
|
| 11 |
-
the server now handles it gracefully.
|
| 12 |
-
3. Per-session cancel token stored in `_active_streamer` so any new
|
| 13 |
-
utterance from the same WS cleanly aborts the previous one.
|
| 14 |
-
4. All other logic (ping/pong, safe send helpers, chat WS) is unchanged.
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
import asyncio
|
| 18 |
import json
|
| 19 |
import os
|
|
@@ -55,7 +39,6 @@ async def root():
|
|
| 55 |
return HTMLResponse("<h2>index.html not found</h2>", status_code=404)
|
| 56 |
|
| 57 |
|
| 58 |
-
# ── Helpers ────────────────────────────────────────────────────────────────────
|
| 59 |
def _ws_open(ws: WebSocket) -> bool:
|
| 60 |
return ws.client_state == WebSocketState.CONNECTED
|
| 61 |
|
|
@@ -80,7 +63,6 @@ async def _safe_bytes(ws: WebSocket, data: bytes) -> bool:
|
|
| 80 |
return False
|
| 81 |
|
| 82 |
|
| 83 |
-
# ── Text chat WebSocket ────────────────────────────────────────────────────────
|
| 84 |
@app.websocket("/ws/chat")
|
| 85 |
async def ws_chat(ws: WebSocket):
|
| 86 |
await ws.accept()
|
|
@@ -118,7 +100,6 @@ async def ws_chat(ws: WebSocket):
|
|
| 118 |
print(f"[CHAT] WS error: {exc}")
|
| 119 |
|
| 120 |
|
| 121 |
-
# ── Voice WebSocket ────────────────────────────────────────────────────────────
|
| 122 |
@app.websocket("/ws/voice")
|
| 123 |
async def ws_voice(ws: WebSocket):
|
| 124 |
await ws.accept()
|
|
@@ -126,7 +107,7 @@ async def ws_voice(ws: WebSocket):
|
|
| 126 |
|
| 127 |
stt = STTProcessor()
|
| 128 |
user_id = "voice_user"
|
| 129 |
-
_active_streamer: ParallelTTSStreamer | None = None
|
| 130 |
|
| 131 |
try:
|
| 132 |
while True:
|
|
@@ -146,18 +127,18 @@ async def ws_voice(ws: WebSocket):
|
|
| 146 |
print(f"[VOICE] Receive error: {exc}")
|
| 147 |
break
|
| 148 |
|
| 149 |
-
|
| 150 |
if "bytes" in data and data["bytes"]:
|
| 151 |
audio_bytes = data["bytes"]
|
| 152 |
print(f"[VOICE] Received utterance: {len(audio_bytes):,} bytes")
|
| 153 |
|
| 154 |
-
|
| 155 |
if _active_streamer is not None:
|
| 156 |
print("[VOICE] Barge-in — cancelling previous TTS.")
|
| 157 |
await _active_streamer.cancel()
|
| 158 |
_active_streamer = None
|
| 159 |
|
| 160 |
-
|
| 161 |
transcript = await stt.transcribe(audio_bytes)
|
| 162 |
|
| 163 |
if not transcript:
|
|
@@ -172,7 +153,7 @@ async def ws_voice(ws: WebSocket):
|
|
| 172 |
if not await _safe_text(ws, {"type": "stt", "text": transcript}):
|
| 173 |
break
|
| 174 |
|
| 175 |
-
|
| 176 |
tts_streamer = ParallelTTSStreamer()
|
| 177 |
_active_streamer = tts_streamer
|
| 178 |
|
|
@@ -198,17 +179,17 @@ async def ws_voice(ws: WebSocket):
|
|
| 198 |
await asyncio.gather(run_ai_and_tts(), stream_tts_audio())
|
| 199 |
_active_streamer = None
|
| 200 |
|
| 201 |
-
|
| 202 |
await _safe_text(ws, {"type": "end"})
|
| 203 |
|
| 204 |
-
|
| 205 |
elif "text" in data and data["text"]:
|
| 206 |
try:
|
| 207 |
msg = json.loads(data["text"])
|
| 208 |
if msg.get("type") == "ping":
|
| 209 |
await _safe_text(ws, {"type": "pong"})
|
| 210 |
|
| 211 |
-
|
| 212 |
elif msg.get("type") == "cancel":
|
| 213 |
if _active_streamer is not None:
|
| 214 |
print("[VOICE] Client cancel signal received.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import json
|
| 3 |
import os
|
|
|
|
| 39 |
return HTMLResponse("<h2>index.html not found</h2>", status_code=404)
|
| 40 |
|
| 41 |
|
|
|
|
| 42 |
def _ws_open(ws: WebSocket) -> bool:
|
| 43 |
return ws.client_state == WebSocketState.CONNECTED
|
| 44 |
|
|
|
|
| 63 |
return False
|
| 64 |
|
| 65 |
|
|
|
|
| 66 |
@app.websocket("/ws/chat")
|
| 67 |
async def ws_chat(ws: WebSocket):
|
| 68 |
await ws.accept()
|
|
|
|
| 100 |
print(f"[CHAT] WS error: {exc}")
|
| 101 |
|
| 102 |
|
|
|
|
| 103 |
@app.websocket("/ws/voice")
|
| 104 |
async def ws_voice(ws: WebSocket):
|
| 105 |
await ws.accept()
|
|
|
|
| 107 |
|
| 108 |
stt = STTProcessor()
|
| 109 |
user_id = "voice_user"
|
| 110 |
+
_active_streamer: ParallelTTSStreamer | None = None
|
| 111 |
|
| 112 |
try:
|
| 113 |
while True:
|
|
|
|
| 127 |
print(f"[VOICE] Receive error: {exc}")
|
| 128 |
break
|
| 129 |
|
| 130 |
+
|
| 131 |
if "bytes" in data and data["bytes"]:
|
| 132 |
audio_bytes = data["bytes"]
|
| 133 |
print(f"[VOICE] Received utterance: {len(audio_bytes):,} bytes")
|
| 134 |
|
| 135 |
+
|
| 136 |
if _active_streamer is not None:
|
| 137 |
print("[VOICE] Barge-in — cancelling previous TTS.")
|
| 138 |
await _active_streamer.cancel()
|
| 139 |
_active_streamer = None
|
| 140 |
|
| 141 |
+
|
| 142 |
transcript = await stt.transcribe(audio_bytes)
|
| 143 |
|
| 144 |
if not transcript:
|
|
|
|
| 153 |
if not await _safe_text(ws, {"type": "stt", "text": transcript}):
|
| 154 |
break
|
| 155 |
|
| 156 |
+
|
| 157 |
tts_streamer = ParallelTTSStreamer()
|
| 158 |
_active_streamer = tts_streamer
|
| 159 |
|
|
|
|
| 179 |
await asyncio.gather(run_ai_and_tts(), stream_tts_audio())
|
| 180 |
_active_streamer = None
|
| 181 |
|
| 182 |
+
|
| 183 |
await _safe_text(ws, {"type": "end"})
|
| 184 |
|
| 185 |
+
|
| 186 |
elif "text" in data and data["text"]:
|
| 187 |
try:
|
| 188 |
msg = json.loads(data["text"])
|
| 189 |
if msg.get("type") == "ping":
|
| 190 |
await _safe_text(ws, {"type": "pong"})
|
| 191 |
|
| 192 |
+
|
| 193 |
elif msg.get("type") == "cancel":
|
| 194 |
if _active_streamer is not None:
|
| 195 |
print("[VOICE] Client cancel signal received.")
|
core/backend.py
CHANGED
|
@@ -1,36 +1,43 @@
|
|
| 1 |
-
from
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
| 6 |
-
from
|
|
|
|
| 7 |
from langgraph.prebuilt import ToolNode, tools_condition
|
| 8 |
-
from langchain_community.tools import DuckDuckGoSearchRun
|
| 9 |
-
from langchain_core.tools import tool
|
| 10 |
-
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, RemoveMessage, SystemMessage
|
| 11 |
-
import aiosqlite, uuid, os, httpx, asyncio
|
| 12 |
from twilio.rest import Client
|
| 13 |
-
from
|
| 14 |
-
import json, pytz
|
| 15 |
-
from datetime import datetime
|
| 16 |
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
| 18 |
class ChatState(TypedDict):
|
| 19 |
-
messages: Annotated[list
|
| 20 |
summary: str
|
| 21 |
|
| 22 |
-
|
| 23 |
-
#
|
| 24 |
-
|
|
|
|
|
|
|
| 25 |
return os.path.join(os.path.dirname(__file__), "daa.db")
|
| 26 |
|
| 27 |
-
def send_sms(to_number: str, message: str):
|
| 28 |
-
client = Client(os.getenv("TWILIO_ACCOUNT_SID"), os.getenv("TWILIO_AUTH_TOKEN"))
|
| 29 |
-
client.messages.create(
|
| 30 |
-
body=message,
|
| 31 |
-
from_=os.getenv("TWILIO_PHONE_NUMBER"),
|
| 32 |
-
to=to_number
|
| 33 |
-
)
|
| 34 |
|
| 35 |
def format_bd_number(num: str) -> str:
|
| 36 |
num = num.strip().replace(" ", "")
|
|
@@ -38,36 +45,50 @@ def format_bd_number(num: str) -> str:
|
|
| 38 |
return "+88" + num
|
| 39 |
if num.startswith("8801"):
|
| 40 |
return "+" + num
|
| 41 |
-
return num
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
|
|
|
|
|
|
|
|
|
| 43 |
@tool
|
| 44 |
def get_bd_time() -> str:
|
| 45 |
-
"""
|
| 46 |
-
|
| 47 |
-
"""
|
| 48 |
-
tz = pytz.timezone("Asia/Dhaka")
|
| 49 |
now = datetime.now(tz)
|
| 50 |
return now.strftime("%Y-%m-%d %H:%M:%S (%A, Bangladesh Time)")
|
| 51 |
|
|
|
|
| 52 |
@tool
|
| 53 |
-
async def search_doctor(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
"""
|
| 55 |
-
Search doctors by name, category, or visiting_days from SQLite database.
|
| 56 |
-
Any combination of filters is supported (OR logic
|
| 57 |
"""
|
| 58 |
-
db_path
|
| 59 |
-
query
|
| 60 |
-
params = []
|
| 61 |
-
conditions = []
|
| 62 |
|
| 63 |
if name:
|
| 64 |
conditions.append("LOWER(doctor_name) LIKE ?")
|
| 65 |
params.append(f"%{name.lower()}%")
|
| 66 |
-
|
| 67 |
if category:
|
| 68 |
conditions.append("LOWER(category) LIKE ?")
|
| 69 |
params.append(f"%{category.lower()}%")
|
| 70 |
-
|
| 71 |
if visiting_days:
|
| 72 |
conditions.append("LOWER(visiting_days) LIKE ?")
|
| 73 |
params.append(f"%{visiting_days.lower()}%")
|
|
@@ -78,119 +99,89 @@ async def search_doctor(name: str = "", category: str = "", visiting_days: str =
|
|
| 78 |
async with aiosqlite.connect(db_path) as db:
|
| 79 |
db.row_factory = aiosqlite.Row
|
| 80 |
cursor = await db.execute(query, params)
|
| 81 |
-
rows
|
| 82 |
|
| 83 |
if not rows:
|
| 84 |
-
return json.dumps({
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
"data": []
|
| 88 |
-
})
|
| 89 |
|
| 90 |
-
return json.dumps({
|
| 91 |
-
"success": True,
|
| 92 |
-
"count": len(rows),
|
| 93 |
-
"data": [dict(r) for r in rows]
|
| 94 |
-
})
|
| 95 |
|
| 96 |
@tool
|
| 97 |
async def search_appointment_by_phone(patient_num: str) -> str:
|
| 98 |
-
"""
|
| 99 |
-
|
| 100 |
-
"""
|
| 101 |
-
db_path = get_db_path()
|
| 102 |
patient_num = format_bd_number(patient_num)
|
| 103 |
|
| 104 |
async with aiosqlite.connect(db_path) as db:
|
| 105 |
db.row_factory = aiosqlite.Row
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
ORDER BY visiting_date ASC
|
| 111 |
-
""", (patient_num,))
|
| 112 |
-
|
| 113 |
rows = await cursor.fetchall()
|
| 114 |
|
| 115 |
if not rows:
|
| 116 |
return json.dumps({
|
| 117 |
"success": False,
|
| 118 |
"message": "No appointments found for this phone number.",
|
| 119 |
-
"data": []
|
| 120 |
})
|
|
|
|
| 121 |
|
| 122 |
-
return json.dumps({
|
| 123 |
-
"success": True,
|
| 124 |
-
"count": len(rows),
|
| 125 |
-
"data": [dict(r) for r in rows]
|
| 126 |
-
})
|
| 127 |
|
| 128 |
@tool
|
| 129 |
-
async def book_appointment(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
"""
|
| 131 |
Book a doctor appointment and save it to the patients table.
|
| 132 |
|
| 133 |
Args:
|
| 134 |
-
doctor_id:
|
| 135 |
-
patient_name:
|
| 136 |
-
patient_age:
|
| 137 |
-
patient_num:
|
| 138 |
visiting_date: Date of visit in YYYY-MM-DD format (e.g. 2025-06-15).
|
| 139 |
-
|
| 140 |
-
Returns a booking confirmation with the new record ID.
|
| 141 |
"""
|
| 142 |
-
db_path
|
|
|
|
| 143 |
|
| 144 |
async with aiosqlite.connect(db_path) as db:
|
| 145 |
db.row_factory = aiosqlite.Row
|
| 146 |
|
| 147 |
-
patient_num = format_bd_number(patient_num)
|
| 148 |
-
|
| 149 |
-
# Verify doctor exists
|
| 150 |
cursor = await db.execute("SELECT * FROM doctors WHERE id = ?", (doctor_id,))
|
| 151 |
doctor = await cursor.fetchone()
|
| 152 |
if not doctor:
|
| 153 |
return f"No doctor found with ID {doctor_id}. Please search for a doctor first."
|
| 154 |
|
| 155 |
-
doctor_data
|
| 156 |
-
doctor_name
|
| 157 |
-
doctor_category = doctor_data.get("
|
| 158 |
|
| 159 |
-
# Check for conflicting booking (same doctor + same date)
|
| 160 |
cursor = await db.execute(
|
| 161 |
"""SELECT id FROM patients
|
| 162 |
-
|
| 163 |
(doctor_name, visiting_date, patient_num),
|
| 164 |
)
|
| 165 |
-
|
| 166 |
-
if conflict:
|
| 167 |
return (
|
| 168 |
f"A booking for {patient_name} with Dr. {doctor_name} "
|
| 169 |
f"on {visiting_date} already exists."
|
| 170 |
)
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
(doctor_name, doctor_category, patient_name, patient_age, patient_num, visiting_date),
|
| 177 |
)
|
| 178 |
await db.commit()
|
| 179 |
|
| 180 |
-
# Send SMS confirmation
|
| 181 |
-
sms_message = (
|
| 182 |
-
f"✅ Appointment Confirmed!\n"
|
| 183 |
-
f"Doctor : {doctor_name}\n"
|
| 184 |
-
f"Patient : {patient_name}\n"
|
| 185 |
-
f"Visit Date : {visiting_date}\n"
|
| 186 |
-
f"Please arrive 10 minutes early."
|
| 187 |
-
)
|
| 188 |
-
# try:
|
| 189 |
-
# send_sms(to_number=patient_num, message=sms_message)
|
| 190 |
-
# sms_status = "📱 SMS confirmation sent."
|
| 191 |
-
# except Exception as e:
|
| 192 |
-
# sms_status = f"⚠️ SMS failed: {str(e)}"
|
| 193 |
-
|
| 194 |
return (
|
| 195 |
f"✅ Appointment Booked!\n"
|
| 196 |
f"━━━━━━━━━━━━━━━━━━━━━━\n"
|
|
@@ -201,182 +192,203 @@ async def book_appointment(doctor_id: int, patient_name: str, patient_age: str,
|
|
| 201 |
f"Contact : {patient_num}\n"
|
| 202 |
f"━━━━━━━━━━━━━━━━━━━━━━\n"
|
| 203 |
f"Please arrive 10 minutes early."
|
| 204 |
-
# f"{sms_status}"
|
| 205 |
)
|
| 206 |
|
|
|
|
|
|
|
| 207 |
async def delete_appointment(patient_num: str, doctor_name: str) -> str:
|
| 208 |
-
"""
|
| 209 |
-
|
| 210 |
-
"""
|
| 211 |
-
db_path = get_db_path()
|
| 212 |
-
# normalize phone number
|
| 213 |
patient_num = format_bd_number(patient_num)
|
| 214 |
|
| 215 |
async with aiosqlite.connect(db_path) as db:
|
| 216 |
db.row_factory = aiosqlite.Row
|
| 217 |
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
row = await cursor.fetchone()
|
| 226 |
-
if not row:
|
| 227 |
-
return json.dumps({
|
| 228 |
-
"success": False,
|
| 229 |
-
"message": "No matching appointment found to delete."
|
| 230 |
-
})
|
| 231 |
-
|
| 232 |
-
# delete appointment
|
| 233 |
-
await db.execute("""
|
| 234 |
-
DELETE FROM patients
|
| 235 |
-
WHERE patient_num = ?
|
| 236 |
-
AND LOWER(doctor_name) = LOWER(?)
|
| 237 |
-
""", (patient_num, doctor_name))
|
| 238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
await db.commit()
|
| 240 |
|
| 241 |
return json.dumps({
|
| 242 |
"success": True,
|
| 243 |
-
"message": f"Appointment with Dr. {doctor_name} deleted successfully."
|
| 244 |
})
|
| 245 |
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
class AIBackend:
|
| 248 |
-
|
|
|
|
| 249 |
load_dotenv()
|
| 250 |
-
os.environ
|
| 251 |
-
|
| 252 |
-
self.
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
self.llm_with_tools = self.llm.bind_tools(self.tools)
|
| 255 |
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
|
|
|
| 259 |
self.checkpointer = AsyncSqliteSaver(self.conn)
|
| 260 |
-
await self.
|
| 261 |
-
self.graph
|
| 262 |
self.summary_graph = self._build_summary_graph()
|
| 263 |
|
| 264 |
-
async def
|
| 265 |
await self.conn.execute("""
|
| 266 |
CREATE TABLE IF NOT EXISTS userid_threadid (
|
| 267 |
-
userId
|
| 268 |
threadId TEXT UNIQUE NOT NULL
|
| 269 |
)
|
| 270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
await self.conn.commit()
|
| 272 |
|
| 273 |
-
#
|
| 274 |
async def summarize_conversation(self, state: ChatState):
|
| 275 |
-
|
| 276 |
messages = state["messages"]
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
- Write the summary in clear bullet points.
|
| 294 |
-
- Focus on information useful for future conversations and contextual continuity.
|
| 295 |
-
- Do NOT include casual greetings or temporary small talk unless important.
|
| 296 |
-
- Keep the summary compact but information-dense.
|
| 297 |
-
"""
|
| 298 |
-
if existing_summary
|
| 299 |
-
else
|
| 300 |
-
"""
|
| 301 |
-
You are creating a long-term conversation memory summary for a chatbot.
|
| 302 |
-
|
| 303 |
-
Summarize the conversation above.
|
| 304 |
-
|
| 305 |
-
Instructions:
|
| 306 |
-
- Capture important user information, goals, preferences, projects, and decisions.
|
| 307 |
-
- Include technical issues, debugging progress, and solutions discussed.
|
| 308 |
-
- Track ongoing tasks or unresolved questions.
|
| 309 |
-
- Ignore casual greetings and low-value chatter.
|
| 310 |
-
- Write concise, structured bullet points.
|
| 311 |
-
- Keep the summary compact but highly informative for future context retention.
|
| 312 |
-
"""
|
| 313 |
-
)
|
| 314 |
-
messages_for_summary = messages + [HumanMessage(content=prompt)]
|
| 315 |
-
response = await self.llm.ainvoke(messages_for_summary)
|
| 316 |
return {
|
| 317 |
"summary": response.content,
|
| 318 |
"messages": [RemoveMessage(id=m.id) for m in messages[:-2]],
|
| 319 |
}
|
| 320 |
|
| 321 |
-
async def should_summarize(self, state: ChatState):
|
| 322 |
-
if len(state["messages"]) > 10
|
| 323 |
-
return "summarize_node"
|
| 324 |
-
return "chat_node"
|
| 325 |
|
| 326 |
-
#
|
| 327 |
async def chat_node(self, state: ChatState):
|
| 328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
messages = state["messages"]
|
| 330 |
|
| 331 |
-
print(
|
| 332 |
print(">>>>>>>>>> CHAT NODE START <<<<<<<<<<")
|
| 333 |
-
|
| 334 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
else:
|
| 336 |
-
|
|
|
|
|
|
|
|
|
|
| 337 |
|
| 338 |
-
print(
|
| 339 |
-
print("[MESSAGES]:")
|
| 340 |
-
for m in messages:
|
| 341 |
-
role = m.__class__.__name__
|
| 342 |
-
print(f" [{role}]: {m.content[:200]}")
|
| 343 |
-
print('$'*50,'\n')
|
| 344 |
-
|
| 345 |
-
if summary:
|
| 346 |
-
summary_message = SystemMessage(
|
| 347 |
-
content=(
|
| 348 |
-
"You are a Bangla voice assistant. You are provided with a condensed memory of previous conversations.\n\n"
|
| 349 |
-
f"Conversation Memory:\n{summary}\n\n"
|
| 350 |
-
"Instructions:\n"
|
| 351 |
-
"- Always respond in Bangla (বাংলা)"
|
| 352 |
-
"- Keep sentences short for speech"
|
| 353 |
-
"- No English unless necessary"
|
| 354 |
-
"- Use this memory as long-term conversational context.\n"
|
| 355 |
-
"- Maintain continuity with the user's previous discussions, projects, goals, and preferences.\n"
|
| 356 |
-
"- Prioritize recent and relevant information when generating responses.\n"
|
| 357 |
-
"- Do not repeat the summary unless necessary.\n"
|
| 358 |
-
"- If new information conflicts with old memory, prefer the latest context.\n"
|
| 359 |
-
"- Use the memory naturally to improve personalization, reasoning, and follow-up responses.\n"
|
| 360 |
-
"- Treat unresolved issues, active projects, and pending tasks as ongoing unless stated otherwise."
|
| 361 |
-
)
|
| 362 |
-
)
|
| 363 |
-
messages = [summary_message] + messages
|
| 364 |
-
response = await self.llm_with_tools.ainvoke(messages)
|
| 365 |
-
print(f"Final [{response.__class__.__name__}]: {response.content[:200]}")
|
| 366 |
print(">>>>>>>>>> CHAT NODE END <<<<<<<<<<")
|
| 367 |
-
print('#'*50)
|
| 368 |
return {"messages": [response]}
|
| 369 |
|
| 370 |
-
#
|
| 371 |
def _build_graph(self):
|
| 372 |
g = StateGraph(ChatState)
|
| 373 |
g.add_node("chat_node", self.chat_node)
|
| 374 |
-
g.add_node("tools",
|
| 375 |
-
|
| 376 |
g.add_edge(START, "chat_node")
|
| 377 |
g.add_conditional_edges("chat_node", tools_condition)
|
| 378 |
g.add_edge("tools", "chat_node")
|
| 379 |
-
|
| 380 |
return g.compile(checkpointer=self.checkpointer)
|
| 381 |
|
| 382 |
def _build_summary_graph(self):
|
|
@@ -386,40 +398,49 @@ class AIBackend:
|
|
| 386 |
g.add_edge("summarize_node", END)
|
| 387 |
return g.compile(checkpointer=self.checkpointer)
|
| 388 |
|
| 389 |
-
#
|
| 390 |
async def ai_only_stream(self, initial_state: dict, config: dict):
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
asyncio.create_task(
|
| 399 |
-
self.summary_graph.ainvoke(
|
| 400 |
)
|
| 401 |
-
print(
|
| 402 |
|
| 403 |
-
#
|
| 404 |
@staticmethod
|
| 405 |
def generate_thread_id() -> str:
|
| 406 |
return str(uuid.uuid4())
|
| 407 |
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
return list(all_threads)
|
| 414 |
|
| 415 |
-
#
|
| 416 |
async def main(self, user_id: str, user_query: str):
|
|
|
|
| 417 |
async with self.conn.execute(
|
| 418 |
-
"SELECT
|
| 419 |
) as cursor:
|
| 420 |
-
|
| 421 |
|
| 422 |
-
if
|
| 423 |
thread_id = user_id + self.generate_thread_id()
|
| 424 |
await self.conn.execute(
|
| 425 |
"INSERT INTO userid_threadid (userId, threadId) VALUES (?, ?)",
|
|
@@ -427,12 +448,12 @@ class AIBackend:
|
|
| 427 |
)
|
| 428 |
await self.conn.commit()
|
| 429 |
else:
|
| 430 |
-
thread_id =
|
| 431 |
|
| 432 |
initial_state = {"messages": [HumanMessage(content=user_query)]}
|
| 433 |
config = {
|
| 434 |
"configurable": {"thread_id": thread_id},
|
| 435 |
-
"metadata":
|
| 436 |
-
"run_name":
|
| 437 |
}
|
| 438 |
return self.ai_only_stream(initial_state, config)
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import uuid
|
| 7 |
+
|
| 8 |
+
import aiosqlite
|
| 9 |
+
import pytz
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
|
| 13 |
+
from langchain_core.messages import (
|
| 14 |
+
AIMessage, AIMessageChunk, HumanMessage, RemoveMessage,
|
| 15 |
+
SystemMessage, ToolMessage,
|
| 16 |
+
)
|
| 17 |
+
from langchain_core.tools import tool
|
| 18 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 19 |
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
| 20 |
+
from langgraph.graph import END, START, StateGraph
|
| 21 |
+
from langgraph.graph.message import add_messages
|
| 22 |
from langgraph.prebuilt import ToolNode, tools_condition
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
from twilio.rest import Client
|
| 24 |
+
from typing import Annotated, TypedDict
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
|
| 27 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 28 |
+
# STATE
|
| 29 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 30 |
class ChatState(TypedDict):
|
| 31 |
+
messages: Annotated[list, add_messages]
|
| 32 |
summary: str
|
| 33 |
|
| 34 |
+
|
| 35 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 36 |
+
# HELPERS
|
| 37 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 38 |
+
def get_db_path() -> str:
|
| 39 |
return os.path.join(os.path.dirname(__file__), "daa.db")
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
def format_bd_number(num: str) -> str:
|
| 43 |
num = num.strip().replace(" ", "")
|
|
|
|
| 45 |
return "+88" + num
|
| 46 |
if num.startswith("8801"):
|
| 47 |
return "+" + num
|
| 48 |
+
return num
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def send_sms(to_number: str, message: str) -> None:
|
| 52 |
+
client = Client(os.getenv("TWILIO_ACCOUNT_SID"), os.getenv("TWILIO_AUTH_TOKEN"))
|
| 53 |
+
client.messages.create(
|
| 54 |
+
body=message,
|
| 55 |
+
from_=os.getenv("TWILIO_PHONE_NUMBER"),
|
| 56 |
+
to=to_number,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
|
| 60 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 61 |
+
# TOOLS
|
| 62 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 63 |
@tool
|
| 64 |
def get_bd_time() -> str:
|
| 65 |
+
"""Get current Bangladesh time (Asia/Dhaka) with weekday name."""
|
| 66 |
+
tz = pytz.timezone("Asia/Dhaka")
|
|
|
|
|
|
|
| 67 |
now = datetime.now(tz)
|
| 68 |
return now.strftime("%Y-%m-%d %H:%M:%S (%A, Bangladesh Time)")
|
| 69 |
|
| 70 |
+
|
| 71 |
@tool
|
| 72 |
+
async def search_doctor(
|
| 73 |
+
name: str = "",
|
| 74 |
+
category: str = "",
|
| 75 |
+
visiting_days: str = "",
|
| 76 |
+
) -> str:
|
| 77 |
"""
|
| 78 |
+
Search doctors by name, category, or visiting_days from the SQLite database.
|
| 79 |
+
Any combination of filters is supported (OR logic across fields).
|
| 80 |
"""
|
| 81 |
+
db_path = get_db_path()
|
| 82 |
+
query = "SELECT * FROM doctors WHERE 1=1"
|
| 83 |
+
params: list = []
|
| 84 |
+
conditions: list[str] = []
|
| 85 |
|
| 86 |
if name:
|
| 87 |
conditions.append("LOWER(doctor_name) LIKE ?")
|
| 88 |
params.append(f"%{name.lower()}%")
|
|
|
|
| 89 |
if category:
|
| 90 |
conditions.append("LOWER(category) LIKE ?")
|
| 91 |
params.append(f"%{category.lower()}%")
|
|
|
|
| 92 |
if visiting_days:
|
| 93 |
conditions.append("LOWER(visiting_days) LIKE ?")
|
| 94 |
params.append(f"%{visiting_days.lower()}%")
|
|
|
|
| 99 |
async with aiosqlite.connect(db_path) as db:
|
| 100 |
db.row_factory = aiosqlite.Row
|
| 101 |
cursor = await db.execute(query, params)
|
| 102 |
+
rows = await cursor.fetchall()
|
| 103 |
|
| 104 |
if not rows:
|
| 105 |
+
return json.dumps({"success": False, "message": "No doctors found.", "data": []})
|
| 106 |
+
|
| 107 |
+
return json.dumps({"success": True, "count": len(rows), "data": [dict(r) for r in rows]})
|
|
|
|
|
|
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
@tool
|
| 111 |
async def search_appointment_by_phone(patient_num: str) -> str:
|
| 112 |
+
"""Search all appointments using the patient's phone number."""
|
| 113 |
+
db_path = get_db_path()
|
|
|
|
|
|
|
| 114 |
patient_num = format_bd_number(patient_num)
|
| 115 |
|
| 116 |
async with aiosqlite.connect(db_path) as db:
|
| 117 |
db.row_factory = aiosqlite.Row
|
| 118 |
+
cursor = await db.execute(
|
| 119 |
+
"SELECT * FROM patients WHERE patient_num = ? ORDER BY visiting_date ASC",
|
| 120 |
+
(patient_num,),
|
| 121 |
+
)
|
|
|
|
|
|
|
|
|
|
| 122 |
rows = await cursor.fetchall()
|
| 123 |
|
| 124 |
if not rows:
|
| 125 |
return json.dumps({
|
| 126 |
"success": False,
|
| 127 |
"message": "No appointments found for this phone number.",
|
| 128 |
+
"data": [],
|
| 129 |
})
|
| 130 |
+
return json.dumps({"success": True, "count": len(rows), "data": [dict(r) for r in rows]})
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
@tool
|
| 134 |
+
async def book_appointment(
|
| 135 |
+
doctor_id: int,
|
| 136 |
+
patient_name: str,
|
| 137 |
+
patient_age: str,
|
| 138 |
+
patient_num: str,
|
| 139 |
+
visiting_date: str,
|
| 140 |
+
) -> str:
|
| 141 |
"""
|
| 142 |
Book a doctor appointment and save it to the patients table.
|
| 143 |
|
| 144 |
Args:
|
| 145 |
+
doctor_id: Doctor's ID from search_doctor results.
|
| 146 |
+
patient_name: Full name of the patient.
|
| 147 |
+
patient_age: Age of the patient (e.g. "32").
|
| 148 |
+
patient_num: Contact phone number of the patient.
|
| 149 |
visiting_date: Date of visit in YYYY-MM-DD format (e.g. 2025-06-15).
|
|
|
|
|
|
|
| 150 |
"""
|
| 151 |
+
db_path = get_db_path()
|
| 152 |
+
patient_num = format_bd_number(patient_num)
|
| 153 |
|
| 154 |
async with aiosqlite.connect(db_path) as db:
|
| 155 |
db.row_factory = aiosqlite.Row
|
| 156 |
|
|
|
|
|
|
|
|
|
|
| 157 |
cursor = await db.execute("SELECT * FROM doctors WHERE id = ?", (doctor_id,))
|
| 158 |
doctor = await cursor.fetchone()
|
| 159 |
if not doctor:
|
| 160 |
return f"No doctor found with ID {doctor_id}. Please search for a doctor first."
|
| 161 |
|
| 162 |
+
doctor_data = dict(doctor)
|
| 163 |
+
doctor_name = doctor_data.get("doctor_name", "Unknown")
|
| 164 |
+
doctor_category = doctor_data.get("category", "Unknown")
|
| 165 |
|
|
|
|
| 166 |
cursor = await db.execute(
|
| 167 |
"""SELECT id FROM patients
|
| 168 |
+
WHERE doctor_name = ? AND visiting_date = ? AND patient_num = ?""",
|
| 169 |
(doctor_name, visiting_date, patient_num),
|
| 170 |
)
|
| 171 |
+
if await cursor.fetchone():
|
|
|
|
| 172 |
return (
|
| 173 |
f"A booking for {patient_name} with Dr. {doctor_name} "
|
| 174 |
f"on {visiting_date} already exists."
|
| 175 |
)
|
| 176 |
|
| 177 |
+
await db.execute(
|
| 178 |
+
"""INSERT INTO patients
|
| 179 |
+
(doctor_name, doctor_category, patient_name, patient_age, patient_num, visiting_date)
|
| 180 |
+
VALUES (?, ?, ?, ?, ?, ?)""",
|
| 181 |
(doctor_name, doctor_category, patient_name, patient_age, patient_num, visiting_date),
|
| 182 |
)
|
| 183 |
await db.commit()
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
return (
|
| 186 |
f"✅ Appointment Booked!\n"
|
| 187 |
f"━━━━━━━━━━━━━━━━━━━━━━\n"
|
|
|
|
| 192 |
f"Contact : {patient_num}\n"
|
| 193 |
f"━━━━━━━━━━━━━━━━━━━━━━\n"
|
| 194 |
f"Please arrive 10 minutes early."
|
|
|
|
| 195 |
)
|
| 196 |
|
| 197 |
+
|
| 198 |
+
@tool
|
| 199 |
async def delete_appointment(patient_num: str, doctor_name: str) -> str:
|
| 200 |
+
"""Delete an appointment using the patient's phone number and doctor name."""
|
| 201 |
+
db_path = get_db_path()
|
|
|
|
|
|
|
|
|
|
| 202 |
patient_num = format_bd_number(patient_num)
|
| 203 |
|
| 204 |
async with aiosqlite.connect(db_path) as db:
|
| 205 |
db.row_factory = aiosqlite.Row
|
| 206 |
|
| 207 |
+
cursor = await db.execute(
|
| 208 |
+
"""SELECT * FROM patients
|
| 209 |
+
WHERE patient_num = ? AND LOWER(doctor_name) = LOWER(?)""",
|
| 210 |
+
(patient_num, doctor_name),
|
| 211 |
+
)
|
| 212 |
+
if not await cursor.fetchone():
|
| 213 |
+
return json.dumps({"success": False, "message": "No matching appointment found."})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
+
await db.execute(
|
| 216 |
+
"""DELETE FROM patients
|
| 217 |
+
WHERE patient_num = ? AND LOWER(doctor_name) = LOWER(?)""",
|
| 218 |
+
(patient_num, doctor_name),
|
| 219 |
+
)
|
| 220 |
await db.commit()
|
| 221 |
|
| 222 |
return json.dumps({
|
| 223 |
"success": True,
|
| 224 |
+
"message": f"Appointment with Dr. {doctor_name} deleted successfully.",
|
| 225 |
})
|
| 226 |
|
| 227 |
+
|
| 228 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 229 |
+
# SYSTEM PROMPT
|
| 230 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 231 |
+
BASE_SYSTEM = (
|
| 232 |
+
"You are a helpful Bangla voice assistant for a doctor appointment system.\n"
|
| 233 |
+
"Rules:\n"
|
| 234 |
+
"- Always respond in Bangla (বাংলা).\n"
|
| 235 |
+
"- Keep sentences short and natural for text-to-speech playback.\n"
|
| 236 |
+
"- Avoid markdown, bullet points, or long lists in voice responses.\n"
|
| 237 |
+
"- Use tools when needed to search doctors or manage appointments.\n"
|
| 238 |
+
"- Be polite, concise, and clear.\n"
|
| 239 |
+
"- Do not use English unless a proper noun requires it.\n"
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
SUMMARY_SYSTEM = (
|
| 243 |
+
BASE_SYSTEM
|
| 244 |
+
+ "\nYou also have a condensed memory of previous conversations:\n\n"
|
| 245 |
+
"{summary}\n\n"
|
| 246 |
+
"Use this memory for continuity. Do not repeat it unless asked."
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 251 |
+
# AGENT
|
| 252 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 253 |
class AIBackend:
|
| 254 |
+
|
| 255 |
+
def __init__(self) -> None:
    """Prepare the chat model and the tool set used by the agent.

    Calls ``load_dotenv()``, gives ``LANGCHAIN_PROJECT`` a fallback value,
    creates the Gemini chat model and binds the appointment tools to it.
    Database-backed members (connection, checkpointer, graphs) are not
    created here; see ``async_setup``.
    """
    load_dotenv()
    # setdefault leaves any value that is already present in the environment.
    os.environ.setdefault("LANGCHAIN_PROJECT", "Doctor Appointment Automation")

    self.llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.3)

    toolbox = [
        search_doctor,
        book_appointment,
        get_bd_time,
        search_appointment_by_phone,
        delete_appointment,
    ]
    self.tools = toolbox
    # The same tool list backs both the graph's tool node and the tool-aware model.
    self.tool_node = ToolNode(toolbox)
    self.llm_with_tools = self.llm.bind_tools(toolbox)
|
| 273 |
|
| 274 |
+
# ── Setup ──────────────────────────────────────────────────────────────────
|
| 275 |
+
async def async_setup(self) -> None:
    """Open the SQLite connection and build everything that depends on it.

    Sets ``self.conn``, ``self.checkpointer``, ``self.graph`` and
    ``self.summary_graph``, and makes sure the application tables exist.
    """
    self.conn = await aiosqlite.connect(get_db_path())
    # The checkpointer must exist before the graphs are compiled with it.
    self.checkpointer = AsyncSqliteSaver(self.conn)

    await self._create_tables()

    self.graph = self._build_graph()
    self.summary_graph = self._build_summary_graph()
|
| 282 |
|
| 283 |
+
async def _create_tables(self) -> None:
|
| 284 |
await self.conn.execute("""
|
| 285 |
CREATE TABLE IF NOT EXISTS userid_threadid (
|
| 286 |
+
userId TEXT UNIQUE NOT NULL,
|
| 287 |
threadId TEXT UNIQUE NOT NULL
|
| 288 |
)
|
| 289 |
+
""")
|
| 290 |
+
await self.conn.execute("""
|
| 291 |
+
CREATE TABLE IF NOT EXISTS doctors (
|
| 292 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 293 |
+
doctor_name TEXT NOT NULL,
|
| 294 |
+
category TEXT NOT NULL,
|
| 295 |
+
visiting_days TEXT NOT NULL,
|
| 296 |
+
chamber TEXT,
|
| 297 |
+
fee TEXT
|
| 298 |
+
)
|
| 299 |
+
""")
|
| 300 |
+
await self.conn.execute("""
|
| 301 |
+
CREATE TABLE IF NOT EXISTS patients (
|
| 302 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 303 |
+
doctor_name TEXT NOT NULL,
|
| 304 |
+
doctor_category TEXT,
|
| 305 |
+
patient_name TEXT NOT NULL,
|
| 306 |
+
patient_age TEXT,
|
| 307 |
+
patient_num TEXT NOT NULL,
|
| 308 |
+
visiting_date TEXT NOT NULL
|
| 309 |
+
)
|
| 310 |
+
""")
|
| 311 |
await self.conn.commit()
|
| 312 |
|
| 313 |
+
# ── Summarise node ─────────────────────────────────────────────────────────
|
| 314 |
async def summarize_conversation(self, state: ChatState):
    """Condense the conversation into a summary and prune old messages.

    Asks ``self.llm`` (the model without tools) to create or update a
    summary from the messages in ``state``.

    Returns a state update holding the new ``summary`` and a
    ``RemoveMessage`` for every message except the last two.
    """
    history = state["messages"]
    prior_summary = state.get("summary", "")

    if not prior_summary:
        instruction = (
            "Summarise this conversation. "
            "Capture goals, decisions, preferences, and unresolved questions. "
            "Be concise and use bullet points."
        )
    else:
        instruction = (
            f"Existing summary:\n{prior_summary}\n\n"
            "Update the summary with the new messages above. "
            "Keep it concise, bullet-pointed, and information-dense. "
            "Preserve unresolved issues and ongoing tasks."
        )

    # The instruction goes last so the model sees the whole history first.
    reply = await self.llm.ainvoke(history + [HumanMessage(content=instruction)])

    # Everything but the two most recent messages is marked for removal.
    removals = [RemoveMessage(id=old.id) for old in history[:-2]]
    return {"summary": reply.content, "messages": removals}
|
| 337 |
|
| 338 |
+
async def should_summarize(self, state: ChatState) -> str:
    """Return the next node name: summarise once the history exceeds ten messages."""
    if len(state["messages"]) > 10:
        return "summarize_node"
    return "chat_node"
|
|
|
|
|
|
|
| 340 |
|
| 341 |
+
# ── Chat node — streaming version ──────────────────────────────────────────
|
| 342 |
async def chat_node(self, state: ChatState):
    """Run one model turn, streaming from the tool-aware model.

    The model is driven with ``astream()`` rather than ``ainvoke()``; per
    the original author's note this lets ``stream_mode='messages'``
    consumers of the graph receive tokens while the node is still running.
    The streamed chunks are folded with ``+`` into one message, which is
    returned as the state update.  Debug information is printed to stdout.
    """
    messages = state["messages"]
    summary = state.get("summary", "")

    rule = "#" * 50
    print(rule)
    print(">>>>>>>>>> CHAT NODE START <<<<<<<<<<")
    print(f"[SUMMARY]: {summary[:120] if summary else 'None'}")
    for m in messages:
        print(f" [{m.__class__.__name__}]: {str(m.content)[:160]}")
    print(rule)

    # Use the summary-aware prompt only when a summary is present.
    if summary:
        system_text = SUMMARY_SYSTEM.format(summary=summary)
    else:
        system_text = BASE_SYSTEM
    prompt = [SystemMessage(content=system_text), *messages]

    # Fold chunks as they arrive; chunk addition merges content and tool calls.
    merged = None
    async for piece in self.llm_with_tools.astream(prompt):
        merged = piece if merged is None else merged + piece

    # An empty stream still has to yield a message for the state.
    response = AIMessage(content="") if merged is None else merged

    print(f"[AI]: {str(response.content)[:200]}")
    print(">>>>>>>>>> CHAT NODE END <<<<<<<<<<")
    return {"messages": [response]}
|
| 383 |
|
| 384 |
+
# ── Graph ──────────────────────────────────────────────────────────────────
|
| 385 |
def _build_graph(self):
    """Compile the chat graph: START -> chat_node, looping through tools.

    ``tools_condition`` decides where to go after ``chat_node``; the
    ``tools`` node always hands control back to ``chat_node``.  The graph is
    compiled with ``self.checkpointer``.
    """
    builder = StateGraph(ChatState)
    for node_name, runnable in (
        ("chat_node", self.chat_node),
        ("tools", self.tool_node),
    ):
        builder.add_node(node_name, runnable)

    builder.add_edge(START, "chat_node")
    builder.add_conditional_edges("chat_node", tools_condition)
    builder.add_edge("tools", "chat_node")

    return builder.compile(checkpointer=self.checkpointer)
|
| 393 |
|
| 394 |
def _build_summary_graph(self):
|
|
|
|
| 398 |
g.add_edge("summarize_node", END)
|
| 399 |
return g.compile(checkpointer=self.checkpointer)
|
| 400 |
|
| 401 |
+
# ── Streaming ──────────────────────────────────────────────────────────────
|
| 402 |
async def ai_only_stream(self, initial_state: dict, config: dict):
    """Yield the AI text content streamed by the chat graph for one turn.

    Runs ``self.graph.astream`` with ``stream_mode="messages"`` and yields
    the ``content`` of every ``AIMessage`` chunk that has non-empty content.

    After the stream is exhausted, if the stored history holds more than ten
    messages, summarisation is started as a background task.  That code only
    runs when the caller consumes the generator to the end.

    Args:
        initial_state: Graph input for this turn.
        config: Runnable config; it is also used to read the thread's state.
    """
    async for chunk, _meta in self.graph.astream(
        initial_state, config=config, stream_mode="messages"
    ):
        if isinstance(chunk, AIMessage) and chunk.content:
            yield chunk.content

    # Auto-summarise in the background when the history grows long.
    current = await self.graph.aget_state(config)
    if len(current.values.get("messages", [])) > 10:
        # The event loop keeps only weak references to tasks, so a task nobody
        # holds can be garbage-collected before it finishes.  Keep a strong
        # reference until it is done.  The set is created lazily here so this
        # method does not depend on __init__ having defined it.
        background = getattr(self, "_background_tasks", None)
        if background is None:
            background = self._background_tasks = set()

        task = asyncio.create_task(
            self.summary_graph.ainvoke(current.values, config=config)
        )
        background.add(task)

        def _finished(done: asyncio.Task) -> None:
            """Drop the reference and report a failed summarisation run."""
            background.discard(done)
            if not done.cancelled() and done.exception() is not None:
                print(f"[SUMMARY] background summarisation failed: {done.exception()!r}")

        task.add_done_callback(_finished)
        print("@" * 20, "Summarisation triggered", "@" * 20)
|
| 423 |
|
| 424 |
+
# ── Thread management ──────────────────────────────────────────────────────
|
| 425 |
@staticmethod
|
| 426 |
def generate_thread_id() -> str:
|
| 427 |
return str(uuid.uuid4())
|
| 428 |
|
| 429 |
+
async def retrieve_all_threads(self) -> list[str]:
    """Return the distinct thread ids found in the checkpointer.

    Every checkpoint yielded by ``self.checkpointer.alist(None)`` contributes
    its ``configurable.thread_id``.  Duplicates are collapsed and the order
    of the result is unspecified.
    """
    distinct = {
        checkpoint.config["configurable"]["thread_id"]
        async for checkpoint in self.checkpointer.alist(None)
    }
    return list(distinct)
|
|
|
|
| 434 |
|
| 435 |
+
# ── Public entry point ─────────────────────────────────────────────────────
|
| 436 |
async def main(self, user_id: str, user_query: str):
|
| 437 |
+
"""Return an async generator of AI text tokens."""
|
| 438 |
async with self.conn.execute(
|
| 439 |
+
"SELECT threadId FROM userid_threadid WHERE userId = ?", (user_id,)
|
| 440 |
) as cursor:
|
| 441 |
+
row = await cursor.fetchone()
|
| 442 |
|
| 443 |
+
if row is None:
|
| 444 |
thread_id = user_id + self.generate_thread_id()
|
| 445 |
await self.conn.execute(
|
| 446 |
"INSERT INTO userid_threadid (userId, threadId) VALUES (?, ?)",
|
|
|
|
| 448 |
)
|
| 449 |
await self.conn.commit()
|
| 450 |
else:
|
| 451 |
+
thread_id = row[0]
|
| 452 |
|
| 453 |
initial_state = {"messages": [HumanMessage(content=user_query)]}
|
| 454 |
config = {
|
| 455 |
"configurable": {"thread_id": thread_id},
|
| 456 |
+
"metadata": {"thread_id": thread_id},
|
| 457 |
+
"run_name": "chat_turn",
|
| 458 |
}
|
| 459 |
return self.ai_only_stream(initial_state, config)
|
frontend/index.html
CHANGED
|
@@ -45,3 +45,4 @@
|
|
| 45 |
<script src="script.js"></script>
|
| 46 |
</body>
|
| 47 |
</html>
|
|
|
|
|
|
| 45 |
<script src="script.js"></script>
|
| 46 |
</body>
|
| 47 |
</html>
|
| 48 |
+
|
frontend/script.js
CHANGED
|
@@ -1,22 +1,3 @@
|
|
| 1 |
-
/* ─────────────────────────────────────────────────────────────────────────────
|
| 2 |
-
script.js — Voice + text chat client
|
| 3 |
-
|
| 4 |
-
Fixes applied
|
| 5 |
-
─────────────
|
| 6 |
-
1. DOUBLE-SEND BUG: silenceTimer is now explicitly cleared whenever
|
| 7 |
-
isProcessing is set to true, so a timer that was already ticking
|
| 8 |
-
can't fire a second stopRecorder() call.
|
| 9 |
-
2. TTS INTERRUPT / BARGE-IN: stopAllAudio() cancels the current
|
| 10 |
-
HTMLAudioElement and sends {"type":"cancel"} to the server so the
|
| 11 |
-
TTS pipeline also aborts server-side.
|
| 12 |
-
3. MARKDOWN RENDERING: AI bubble uses marked.parse() instead of
|
| 13 |
-
textContent so Bangla markdown (bold, lists, headings) renders
|
| 14 |
-
correctly in the chat.
|
| 15 |
-
4. VAD barge-in path: if the user starts speaking while TTS is playing
|
| 16 |
-
the audio stops immediately, isProcessing resets, and the new
|
| 17 |
-
utterance is captured normally.
|
| 18 |
-
───────────────────────────────────────────────────────────────────────────── */
|
| 19 |
-
|
| 20 |
const chatBox = document.getElementById('chat-box');
|
| 21 |
const sendBtn = document.getElementById('send-btn');
|
| 22 |
const textInput = document.getElementById('text-input');
|
|
@@ -24,12 +5,10 @@ const micBtn = document.getElementById('mic-btn');
|
|
| 24 |
|
| 25 |
const userId = 'walid';
|
| 26 |
|
| 27 |
-
// ── WebSockets ────────────────────────────────────────────────────────────────
|
| 28 |
const chatSocket = new WebSocket('ws://127.0.0.1:8679/ws/chat');
|
| 29 |
const voiceSocket = new WebSocket('ws://127.0.0.1:8679/ws/voice');
|
| 30 |
voiceSocket.binaryType = 'arraybuffer';
|
| 31 |
|
| 32 |
-
// ── State ─────────────────────────────────────────────────────────────────────
|
| 33 |
let micStream = null;
|
| 34 |
let audioContext = null;
|
| 35 |
let analyser = null;
|
|
@@ -39,18 +18,98 @@ let isListening = false;
|
|
| 39 |
let isSpeaking = false;
|
| 40 |
let silenceTimer = null;
|
| 41 |
let vadInterval = null;
|
| 42 |
-
let isProcessing = false;
|
| 43 |
|
| 44 |
let currentAIMessage = null;
|
| 45 |
-
let
|
| 46 |
-
let playbackChain = Promise.resolve();
|
| 47 |
|
| 48 |
-
|
| 49 |
-
const
|
| 50 |
-
const SILENCE_TIMEOUT_MS = 3000; // ms of silence before sending utterance
|
| 51 |
const VAD_POLL_MS = 100;
|
| 52 |
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
sendBtn.onclick = sendTextMessage;
|
| 55 |
textInput.addEventListener('keydown', (e) => {
|
| 56 |
if (e.key === 'Enter') sendTextMessage();
|
|
@@ -77,7 +136,6 @@ chatSocket.onmessage = (e) => {
|
|
| 77 |
chatSocket.onerror = (e) => console.error('Chat WS error:', e);
|
| 78 |
chatSocket.onclose = () => console.log('Chat WS closed');
|
| 79 |
|
| 80 |
-
// ── Voice WebSocket events ────────────────────────────────────────────────────
|
| 81 |
voiceSocket.onopen = () => console.log('[WS] Voice connected');
|
| 82 |
voiceSocket.onclose = () => {
|
| 83 |
console.log('[WS] Voice closed');
|
|
@@ -86,7 +144,6 @@ voiceSocket.onclose = () => {
|
|
| 86 |
voiceSocket.onerror = (e) => console.error('[WS] Voice error:', e);
|
| 87 |
|
| 88 |
voiceSocket.onmessage = (event) => {
|
| 89 |
-
// Binary → audio playback
|
| 90 |
if (event.data instanceof ArrayBuffer) {
|
| 91 |
enqueueAudio(event.data);
|
| 92 |
return;
|
|
@@ -106,25 +163,22 @@ voiceSocket.onmessage = (event) => {
|
|
| 106 |
break;
|
| 107 |
|
| 108 |
case 'llm_token':
|
| 109 |
-
// FIX: stream tokens into a div; final markdown render happens on 'end'
|
| 110 |
if (!currentAIMessage) {
|
| 111 |
currentAIMessage = appendMessage('', 'ai');
|
| 112 |
currentAIMessage._raw = '';
|
| 113 |
}
|
| 114 |
currentAIMessage._raw += msg.token;
|
| 115 |
-
// Live preview: render markdown progressively
|
| 116 |
currentAIMessage.innerHTML = marked.parse(currentAIMessage._raw);
|
| 117 |
chatBox.scrollTop = chatBox.scrollHeight;
|
| 118 |
break;
|
| 119 |
|
| 120 |
case 'end':
|
| 121 |
-
// Ensure final markdown render
|
| 122 |
if (currentAIMessage && currentAIMessage._raw) {
|
| 123 |
currentAIMessage.innerHTML = marked.parse(currentAIMessage._raw);
|
| 124 |
}
|
| 125 |
currentAIMessage = null;
|
| 126 |
-
|
| 127 |
-
|
| 128 |
break;
|
| 129 |
|
| 130 |
case 'error':
|
|
@@ -137,70 +191,18 @@ voiceSocket.onmessage = (event) => {
|
|
| 137 |
break;
|
| 138 |
|
| 139 |
default:
|
| 140 |
-
console.log('[WS] Unknown msg:', msg.type);
|
| 141 |
}
|
| 142 |
};
|
| 143 |
|
| 144 |
-
// ── Audio playback ─────────────────────────────────────────────────────────────
|
| 145 |
-
function enqueueAudio(buffer) {
|
| 146 |
-
playbackChain = playbackChain.then(() => playBuffer(buffer));
|
| 147 |
-
}
|
| 148 |
-
|
| 149 |
-
function playBuffer(buffer) {
|
| 150 |
-
return new Promise((resolve) => {
|
| 151 |
-
if (isProcessing === false) {
|
| 152 |
-
resolve();
|
| 153 |
-
return;
|
| 154 |
-
} // cancelled mid-chain
|
| 155 |
-
|
| 156 |
-
const blob = new Blob([buffer], { type: 'audio/mpeg' });
|
| 157 |
-
const url = URL.createObjectURL(blob);
|
| 158 |
-
const audio = new Audio(url);
|
| 159 |
-
currentAudio = audio;
|
| 160 |
-
|
| 161 |
-
const done = () => {
|
| 162 |
-
URL.revokeObjectURL(url);
|
| 163 |
-
currentAudio = null;
|
| 164 |
-
resolve();
|
| 165 |
-
};
|
| 166 |
-
audio.onended = done;
|
| 167 |
-
audio.onerror = () => {
|
| 168 |
-
console.warn('[AUDIO] playback error');
|
| 169 |
-
done();
|
| 170 |
-
};
|
| 171 |
-
audio.play().catch(() => done());
|
| 172 |
-
});
|
| 173 |
-
}
|
| 174 |
-
|
| 175 |
-
/**
|
| 176 |
-
* Stop all queued and current audio immediately.
|
| 177 |
-
* Also sends a cancel signal to the server so TTS generation stops.
|
| 178 |
-
*/
|
| 179 |
-
function stopAllAudio() {
|
| 180 |
-
// Replace the chain with an already-resolved promise so queued buffers
|
| 181 |
-
// that haven't started yet are silently dropped.
|
| 182 |
-
playbackChain = Promise.resolve();
|
| 183 |
-
|
| 184 |
-
if (currentAudio) {
|
| 185 |
-
currentAudio.pause();
|
| 186 |
-
currentAudio.src = '';
|
| 187 |
-
currentAudio = null;
|
| 188 |
-
}
|
| 189 |
-
|
| 190 |
-
// Tell server to abort TTS pipeline
|
| 191 |
-
if (voiceSocket.readyState === WebSocket.OPEN) {
|
| 192 |
-
voiceSocket.send(JSON.stringify({ type: 'cancel' }));
|
| 193 |
-
}
|
| 194 |
-
}
|
| 195 |
-
|
| 196 |
-
// ── Mic button ────────────────────────────────────────────────────────────────
|
| 197 |
micBtn.onclick = async () => {
|
| 198 |
if (!isListening) await startListening();
|
| 199 |
else stopListening();
|
| 200 |
};
|
| 201 |
|
| 202 |
-
// ── Start continuous listening with VAD ───────────────────────────────────────
|
| 203 |
async function startListening() {
|
|
|
|
|
|
|
| 204 |
try {
|
| 205 |
micStream = await navigator.mediaDevices.getUserMedia({
|
| 206 |
audio: {
|
|
@@ -228,14 +230,12 @@ async function startListening() {
|
|
| 228 |
vadInterval = setInterval(vadTick, VAD_POLL_MS);
|
| 229 |
}
|
| 230 |
|
| 231 |
-
// ── Stop everything ───────────────────────────────────────────────────────────
|
| 232 |
function stopListening() {
|
| 233 |
clearInterval(vadInterval);
|
| 234 |
clearTimeout(silenceTimer);
|
| 235 |
vadInterval = silenceTimer = null;
|
| 236 |
|
| 237 |
-
if (isSpeaking) stopRecorder(true);
|
| 238 |
-
|
| 239 |
stopAllAudio();
|
| 240 |
|
| 241 |
micStream?.getTracks().forEach((t) => t.stop());
|
|
@@ -246,7 +246,6 @@ function stopListening() {
|
|
| 246 |
setMicStatus('off');
|
| 247 |
}
|
| 248 |
|
| 249 |
-
// ── VAD polling ───────────────────────────────────────────────────────────────
|
| 250 |
function vadTick() {
|
| 251 |
if (!analyser) return;
|
| 252 |
|
|
@@ -258,19 +257,18 @@ function vadTick() {
|
|
| 258 |
const speaking = db > SILENCE_THRESHOLD_DB;
|
| 259 |
|
| 260 |
if (speaking) {
|
| 261 |
-
// FIX: barge-in — user started talking while TTS is playing
|
| 262 |
if (isProcessing) {
|
| 263 |
-
console.log('[VAD] Barge-in
|
| 264 |
stopAllAudio();
|
| 265 |
isProcessing = false;
|
| 266 |
}
|
| 267 |
|
| 268 |
-
// FIX: clear any pending silence timer so it can't double-fire
|
| 269 |
clearTimeout(silenceTimer);
|
| 270 |
silenceTimer = null;
|
| 271 |
|
| 272 |
if (!isSpeaking) {
|
| 273 |
isSpeaking = true;
|
|
|
|
| 274 |
startRecorder();
|
| 275 |
setMicStatus('recording');
|
| 276 |
}
|
|
@@ -280,9 +278,9 @@ function vadTick() {
|
|
| 280 |
silenceTimer = null;
|
| 281 |
isSpeaking = false;
|
| 282 |
|
| 283 |
-
// FIX: set isProcessing *before* stopping the recorder so that
|
| 284 |
-
// if vadTick fires again during onstop it sees the flag and skips.
|
| 285 |
isProcessing = true;
|
|
|
|
|
|
|
| 286 |
stopRecorder(false);
|
| 287 |
setMicStatus('processing');
|
| 288 |
}, SILENCE_TIMEOUT_MS);
|
|
@@ -290,7 +288,6 @@ function vadTick() {
|
|
| 290 |
}
|
| 291 |
}
|
| 292 |
|
| 293 |
-
// ── Recorder ──────────────────────────────────────────────────────────────────
|
| 294 |
function startRecorder() {
|
| 295 |
if (!micStream) return;
|
| 296 |
audioChunks = [];
|
|
@@ -335,7 +332,6 @@ function stopRecorder(discard = false) {
|
|
| 335 |
mediaRecorder = null;
|
| 336 |
}
|
| 337 |
|
| 338 |
-
// ── UI helpers ────────────────────────────────────────────────────────────────
|
| 339 |
function setMicStatus(state) {
|
| 340 |
const labels = {
|
| 341 |
off: '🎤 Start Voice',
|
|
@@ -352,7 +348,6 @@ function appendMessage(text, sender) {
|
|
| 352 |
div.className = `message ${sender}`;
|
| 353 |
|
| 354 |
if (sender === 'ai' && typeof marked !== 'undefined') {
|
| 355 |
-
// FIX: render Bangla markdown (bold, lists, headings) properly
|
| 356 |
div.innerHTML = marked.parse(text);
|
| 357 |
} else {
|
| 358 |
div.textContent = text;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
const chatBox = document.getElementById('chat-box');
|
| 2 |
const sendBtn = document.getElementById('send-btn');
|
| 3 |
const textInput = document.getElementById('text-input');
|
|
|
|
| 5 |
|
| 6 |
const userId = 'walid';
|
| 7 |
|
|
|
|
| 8 |
const chatSocket = new WebSocket('ws://127.0.0.1:8679/ws/chat');
|
| 9 |
const voiceSocket = new WebSocket('ws://127.0.0.1:8679/ws/voice');
|
| 10 |
voiceSocket.binaryType = 'arraybuffer';
|
| 11 |
|
|
|
|
| 12 |
let micStream = null;
|
| 13 |
let audioContext = null;
|
| 14 |
let analyser = null;
|
|
|
|
| 18 |
let isSpeaking = false;
|
| 19 |
let silenceTimer = null;
|
| 20 |
let vadInterval = null;
|
| 21 |
+
let isProcessing = false;
|
| 22 |
|
| 23 |
let currentAIMessage = null;
|
| 24 |
+
let _playbackCancelled = false;
|
|
|
|
| 25 |
|
| 26 |
+
const SILENCE_THRESHOLD_DB = -45;
|
| 27 |
+
const SILENCE_TIMEOUT_MS = 1200;
|
|
|
|
| 28 |
const VAD_POLL_MS = 100;
|
| 29 |
|
| 30 |
+
let _playCtx = null;
|
| 31 |
+
let _schedEndTime = 0;
|
| 32 |
+
let _endTimer = null;
|
| 33 |
+
|
| 34 |
+
/**
 * Return the shared playback AudioContext, creating a new one when none
 * exists or the previous one was closed. A fresh context resets the
 * scheduling cursor; a suspended context is asked to resume.
 */
function _getPlayCtx() {
    const alive = Boolean(_playCtx) && _playCtx.state !== 'closed';

    if (!alive) {
        const ContextCtor = window.AudioContext || window.webkitAudioContext;
        _playCtx = new ContextCtor();
        _schedEndTime = 0;
    }

    if (_playCtx.state === 'suspended') {
        _playCtx.resume();
    }
    return _playCtx;
}
|
| 43 |
+
|
| 44 |
+
/**
 * Decode one audio chunk received from the server and schedule it right
 * after the previously scheduled chunk.
 *
 * Decoding starts immediately, so consecutive chunks may decode
 * concurrently, but scheduling always happens in arrival order: each chunk
 * waits for the previous chunk's turn before it touches `_schedEndTime`.
 * Without that, a short chunk that decodes faster than the one before it
 * would be played first.
 *
 * A chunk is dropped when playback was cancelled, when decoding fails, or
 * when the AudioContext it was decoded for has since been closed or
 * replaced. Its timestamps belong to the old context's clock and would
 * corrupt the cursor of the new one.
 *
 * @param {ArrayBuffer} buffer Encoded audio bytes for one chunk.
 * @returns {Promise<void>} Settles once the chunk is scheduled or dropped.
 */
async function enqueueAudio(buffer) {
    if (_playbackCancelled) return;

    const ctx = _getPlayCtx();

    // slice(0) hands decodeAudioData its own copy, as the original code did.
    const decoding = ctx.decodeAudioData(buffer.slice(0)).catch((err) => {
        console.warn('[AUDIO] decode error:', err);
        return null;
    });

    // Tail of the ordering chain, kept on the function object so this block
    // does not need an extra module-level variable.
    const previousTurn = enqueueAudio._tail || Promise.resolve();

    const turn = previousTurn
        .then(() => decoding)
        .then((decoded) => {
            if (!decoded || _playbackCancelled) return;
            // Context was closed or swapped while this chunk was decoding.
            if (ctx !== _playCtx || ctx.state === 'closed') return;

            const src = ctx.createBufferSource();
            src.buffer = decoded;
            src.connect(ctx.destination);

            // Start slightly in the future, or right after the previous chunk.
            const startAt = Math.max(ctx.currentTime + 0.02, _schedEndTime);
            src.start(startAt);
            _schedEndTime = startAt + decoded.duration;
        });

    // A failure in one turn must not block the chunks queued behind it.
    enqueueAudio._tail = turn.catch(() => {});
    return turn;
}
|
| 67 |
+
|
| 68 |
+
/**
 * Invoked when the server sends `{type:"end"}`.
 * Arms a timer that calls `_onPlaybackFinished()` once the audio scheduled
 * so far should have finished playing (plus a 120 ms margin). When there is
 * no usable AudioContext the callback runs immediately.
 */
function _schedulePlaybackEnd() {
    clearTimeout(_endTimer);

    const ctx = _playCtx;
    const ctxUsable = Boolean(ctx) && ctx.state !== 'closed';
    if (!ctxUsable) {
        _onPlaybackFinished();
        return;
    }

    // Time left until the last scheduled chunk ends, never negative.
    const msLeft = Math.max(0, (_schedEndTime - ctx.currentTime) * 1000);
    const delay = msLeft + 120;

    _endTimer = setTimeout(() => {
        if (_playbackCancelled) return;
        _onPlaybackFinished();
    }, delay);
}
|
| 87 |
+
|
| 88 |
+
/**
 * Mark the current turn as finished and, while the mic is still active,
 * switch the mic status back to "listening".
 */
function _onPlaybackFinished() {
    isProcessing = false;
    if (!isListening) return;
    setMicStatus('listening');
}
|
| 92 |
+
|
| 93 |
+
/**
 * Stop all queued and currently playing audio immediately.
 * Sets the cancelled flag, clears the end-of-playback timer, closes the
 * AudioContext (which also silences sources scheduled for the future) and
 * sends `{type:"cancel"}` to the server when the voice socket is open.
 */
function stopAllAudio() {
    _playbackCancelled = true;

    clearTimeout(_endTimer);
    _endTimer = null;

    // Detach the shared context first, then close the detached instance.
    const ctx = _playCtx;
    _playCtx = null;
    _schedEndTime = 0;
    if (ctx && ctx.state !== 'closed') {
        ctx.close().catch(() => {});
    }

    const socketOpen = voiceSocket.readyState === WebSocket.OPEN;
    if (socketOpen) {
        voiceSocket.send(JSON.stringify({ type: 'cancel' }));
    }
}
|
| 112 |
+
|
| 113 |
sendBtn.onclick = sendTextMessage;
|
| 114 |
textInput.addEventListener('keydown', (e) => {
|
| 115 |
if (e.key === 'Enter') sendTextMessage();
|
|
|
|
| 136 |
chatSocket.onerror = (e) => console.error('Chat WS error:', e);
|
| 137 |
chatSocket.onclose = () => console.log('Chat WS closed');
|
| 138 |
|
|
|
|
| 139 |
voiceSocket.onopen = () => console.log('[WS] Voice connected');
|
| 140 |
voiceSocket.onclose = () => {
|
| 141 |
console.log('[WS] Voice closed');
|
|
|
|
| 144 |
voiceSocket.onerror = (e) => console.error('[WS] Voice error:', e);
|
| 145 |
|
| 146 |
voiceSocket.onmessage = (event) => {
|
|
|
|
| 147 |
if (event.data instanceof ArrayBuffer) {
|
| 148 |
enqueueAudio(event.data);
|
| 149 |
return;
|
|
|
|
| 163 |
break;
|
| 164 |
|
| 165 |
case 'llm_token':
|
|
|
|
| 166 |
if (!currentAIMessage) {
|
| 167 |
currentAIMessage = appendMessage('', 'ai');
|
| 168 |
currentAIMessage._raw = '';
|
| 169 |
}
|
| 170 |
currentAIMessage._raw += msg.token;
|
|
|
|
| 171 |
currentAIMessage.innerHTML = marked.parse(currentAIMessage._raw);
|
| 172 |
chatBox.scrollTop = chatBox.scrollHeight;
|
| 173 |
break;
|
| 174 |
|
| 175 |
case 'end':
|
|
|
|
| 176 |
if (currentAIMessage && currentAIMessage._raw) {
|
| 177 |
currentAIMessage.innerHTML = marked.parse(currentAIMessage._raw);
|
| 178 |
}
|
| 179 |
currentAIMessage = null;
|
| 180 |
+
|
| 181 |
+
_schedulePlaybackEnd();
|
| 182 |
break;
|
| 183 |
|
| 184 |
case 'error':
|
|
|
|
| 191 |
break;
|
| 192 |
|
| 193 |
default:
|
| 194 |
+
console.log('[WS] Unknown msg type:', msg.type);
|
| 195 |
}
|
| 196 |
};
|
| 197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
micBtn.onclick = async () => {
|
| 199 |
if (!isListening) await startListening();
|
| 200 |
else stopListening();
|
| 201 |
};
|
| 202 |
|
|
|
|
| 203 |
async function startListening() {
|
| 204 |
+
_getPlayCtx();
|
| 205 |
+
|
| 206 |
try {
|
| 207 |
micStream = await navigator.mediaDevices.getUserMedia({
|
| 208 |
audio: {
|
|
|
|
| 230 |
vadInterval = setInterval(vadTick, VAD_POLL_MS);
|
| 231 |
}
|
| 232 |
|
|
|
|
| 233 |
function stopListening() {
|
| 234 |
clearInterval(vadInterval);
|
| 235 |
clearTimeout(silenceTimer);
|
| 236 |
vadInterval = silenceTimer = null;
|
| 237 |
|
| 238 |
+
if (isSpeaking) stopRecorder(true);
|
|
|
|
| 239 |
stopAllAudio();
|
| 240 |
|
| 241 |
micStream?.getTracks().forEach((t) => t.stop());
|
|
|
|
| 246 |
setMicStatus('off');
|
| 247 |
}
|
| 248 |
|
|
|
|
| 249 |
function vadTick() {
|
| 250 |
if (!analyser) return;
|
| 251 |
|
|
|
|
| 257 |
const speaking = db > SILENCE_THRESHOLD_DB;
|
| 258 |
|
| 259 |
if (speaking) {
|
|
|
|
| 260 |
if (isProcessing) {
|
| 261 |
+
console.log('[VAD] Barge-in — stopping TTS.');
|
| 262 |
stopAllAudio();
|
| 263 |
isProcessing = false;
|
| 264 |
}
|
| 265 |
|
|
|
|
| 266 |
clearTimeout(silenceTimer);
|
| 267 |
silenceTimer = null;
|
| 268 |
|
| 269 |
if (!isSpeaking) {
|
| 270 |
isSpeaking = true;
|
| 271 |
+
_playbackCancelled = false;
|
| 272 |
startRecorder();
|
| 273 |
setMicStatus('recording');
|
| 274 |
}
|
|
|
|
| 278 |
silenceTimer = null;
|
| 279 |
isSpeaking = false;
|
| 280 |
|
|
|
|
|
|
|
| 281 |
isProcessing = true;
|
| 282 |
+
_playbackCancelled = false;
|
| 283 |
+
|
| 284 |
stopRecorder(false);
|
| 285 |
setMicStatus('processing');
|
| 286 |
}, SILENCE_TIMEOUT_MS);
|
|
|
|
| 288 |
}
|
| 289 |
}
|
| 290 |
|
|
|
|
| 291 |
function startRecorder() {
|
| 292 |
if (!micStream) return;
|
| 293 |
audioChunks = [];
|
|
|
|
| 332 |
mediaRecorder = null;
|
| 333 |
}
|
| 334 |
|
|
|
|
| 335 |
function setMicStatus(state) {
|
| 336 |
const labels = {
|
| 337 |
off: '🎤 Start Voice',
|
|
|
|
| 348 |
div.className = `message ${sender}`;
|
| 349 |
|
| 350 |
if (sender === 'ai' && typeof marked !== 'undefined') {
|
|
|
|
| 351 |
div.innerHTML = marked.parse(text);
|
| 352 |
} else {
|
| 353 |
div.textContent = text;
|
services/streaming.py
CHANGED
|
@@ -1,172 +1,237 @@
|
|
| 1 |
-
"""
|
| 2 |
-
services/streaming.py — Parallel + ordered TTS streamer
|
| 3 |
-
|
| 4 |
-
Fixes applied
|
| 5 |
-
─────────────
|
| 6 |
-
1. BUFFER RACE — self.buffer is now only mutated while holding
|
| 7 |
-
self._flush_lock, so add_token() and _schedule_flush() can never
|
| 8 |
-
interleave partial writes.
|
| 9 |
-
2. CANCELLATION — ParallelTTSStreamer.cancel() drops all pending tasks
|
| 10 |
-
and poisons the queue with a sentinel so stream_audio() exits
|
| 11 |
-
immediately. app.py calls cancel() when the user starts speaking
|
| 12 |
-
mid-playback, giving true barge-in / interrupt behaviour.
|
| 13 |
-
3. Markdown stripping (_clean_for_tts) is unchanged.
|
| 14 |
-
4. Audio ordering guarantee is unchanged (task-chain pattern).
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
from __future__ import annotations
|
| 18 |
|
| 19 |
import asyncio
|
| 20 |
import re
|
|
|
|
|
|
|
| 21 |
|
| 22 |
import edge_tts
|
| 23 |
|
| 24 |
-
VOICE
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
-
# ── Markdown → plain text ──────────────────────────────────────────────────────
|
| 31 |
def _clean_for_tts(text: str) -> str:
|
| 32 |
-
text = re.sub(r"\*{1,3}",
|
| 33 |
-
text = re.sub(r"#+\s*",
|
| 34 |
-
text = re.sub(r"^\s*[-•]\s*",
|
| 35 |
text = re.sub(r"^\s*[\d০-৯]+[.)]\s*", "", text, flags=re.MULTILINE)
|
| 36 |
-
text = re.sub(r"`+",
|
| 37 |
-
text = re.sub(r"\n{2,}",
|
| 38 |
return text.strip()
|
| 39 |
|
| 40 |
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
class ParallelTTSStreamer:
|
| 43 |
"""
|
| 44 |
-
Collects LLM tokens → prosodic chunks → parallel edge-tts
|
| 45 |
-
|
| 46 |
|
| 47 |
Usage
|
| 48 |
─────
|
| 49 |
streamer = ParallelTTSStreamer()
|
| 50 |
|
| 51 |
-
|
| 52 |
await streamer.add_token(token)
|
| 53 |
-
await streamer.flush()
|
| 54 |
|
| 55 |
-
|
| 56 |
-
async for
|
| 57 |
-
await ws.send_bytes(
|
| 58 |
|
| 59 |
-
|
| 60 |
await streamer.cancel()
|
| 61 |
"""
|
| 62 |
|
| 63 |
def __init__(self, voice: str = VOICE) -> None:
|
| 64 |
self.voice = voice
|
| 65 |
self.buffer = ""
|
| 66 |
-
self.queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
| 67 |
-
self._prev_task: asyncio.Task | None = None
|
| 68 |
-
self._flush_lock = asyncio.Lock()
|
| 69 |
self._cancelled = False
|
| 70 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
-
# ── Token intake ───────────────────────────────────────────────────────────
|
| 73 |
async def add_token(self, token: str) -> None:
|
| 74 |
if not token or self._cancelled:
|
| 75 |
return
|
| 76 |
|
| 77 |
-
|
| 78 |
-
async with self._flush_lock:
|
| 79 |
-
self.buffer += token
|
| 80 |
-
should_flush = (
|
| 81 |
-
any(ch in FLUSH_TRIGGERS for ch in token)
|
| 82 |
-
or len(self.buffer) >= FLUSH_LEN
|
| 83 |
-
)
|
| 84 |
|
| 85 |
-
if
|
| 86 |
-
|
|
|
|
| 87 |
|
| 88 |
-
|
| 89 |
-
async def
|
| 90 |
if self._cancelled:
|
|
|
|
| 91 |
return
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
self.buffer = ""
|
| 96 |
|
| 97 |
text = _clean_for_tts(raw)
|
| 98 |
if len(text) < MIN_CHARS:
|
| 99 |
return
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
self._tasks.append(task)
|
| 105 |
-
task.add_done_callback(
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
if
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
# ── Flush remaining buffer ─────────────────────────────────────────────────
|
| 135 |
async def flush(self) -> None:
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
|
|
|
| 146 |
async def cancel(self) -> None:
|
| 147 |
"""
|
| 148 |
-
Immediately abort all in-flight
|
| 149 |
-
|
|
|
|
| 150 |
"""
|
| 151 |
self._cancelled = True
|
| 152 |
|
| 153 |
-
# Cancel all pending asyncio tasks
|
| 154 |
for task in list(self._tasks):
|
| 155 |
task.cancel()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
-
# Drain and poison the queue so stream_audio() exits
|
| 158 |
-
while not self.queue.empty():
|
| 159 |
-
try:
|
| 160 |
-
self.queue.get_nowait()
|
| 161 |
-
except asyncio.QueueEmpty:
|
| 162 |
-
break
|
| 163 |
-
await self.queue.put(None) # sentinel → stream_audio exits
|
| 164 |
|
| 165 |
-
# ── Audio consumer ─────────────────────────────────────────────────────────
|
| 166 |
async def stream_audio(self):
|
| 167 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
while True:
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import asyncio
|
| 4 |
import re
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Optional
|
| 7 |
|
| 8 |
import edge_tts
|
| 9 |
|
| 10 |
+
# Default voice name handed to edge_tts.Communicate when ParallelTTSStreamer
# is constructed without an explicit ``voice`` argument.
VOICE = "bn-BD-NabanitaNeural"


# Flush thresholds used by _should_flush(); all are compared against
# len(buffer), i.e. they are measured in characters of buffered text.
# The first chunk of a reply uses the smaller FIRST_* pair, every later
# chunk uses the SUBSEQUENT_* pair.
#   *_BOUNDARY_MIN : minimum length before a sentence boundary triggers a flush
#   *_HARD         : length at which the buffer is flushed regardless of boundary
FIRST_FLUSH_BOUNDARY_MIN = 25
FIRST_FLUSH_HARD = 70
SUBSEQUENT_FLUSH_BOUNDARY_MIN = 40
SUBSEQUENT_FLUSH_HARD = 110
# Chunks whose cleaned text is shorter than this are dropped instead of
# being sent to TTS (see ParallelTTSStreamer._schedule_chunk).
MIN_CHARS = 4

# Characters that end a sentence (includes the danda "।" and double danda "॥").
SENTENCE_BOUNDARIES = frozenset(".!?।॥\n")
# Characters that end a clause; these only trigger a flush once the buffer
# has reached 80% of the hard limit.
CLAUSE_BOUNDARIES = frozenset(",;:—–")
|
| 21 |
|
| 22 |
|
|
|
|
| 23 |
def _clean_for_tts(text: str) -> str:
|
| 24 |
+
text = re.sub(r"\*{1,3}", "", text)
|
| 25 |
+
text = re.sub(r"#+\s*", "", text)
|
| 26 |
+
text = re.sub(r"^\s*[-•]\s*", "", text, flags=re.MULTILINE)
|
| 27 |
text = re.sub(r"^\s*[\d০-৯]+[.)]\s*", "", text, flags=re.MULTILINE)
|
| 28 |
+
text = re.sub(r"`+", "", text)
|
| 29 |
+
text = re.sub(r"\n{2,}", "\n", text)
|
| 30 |
return text.strip()
|
| 31 |
|
| 32 |
|
| 33 |
+
def _should_flush(buffer: str, first_chunk: bool) -> bool:
    """
    Decide whether *buffer* is ready to be handed to TTS.

    The first chunk of a reply uses the FIRST_* thresholds, later chunks
    the SUBSEQUENT_* thresholds. A flush is requested when any of these
    holds:

    * the buffer has reached the hard length limit, even mid-sentence;
    * the buffer ends on a sentence boundary and has reached the
      boundary minimum;
    * the buffer ends on a clause boundary and has reached 80% of the
      hard limit.

    An empty buffer never flushes.
    """
    if not buffer:
        return False

    if first_chunk:
        soft_min, hard_cap = FIRST_FLUSH_BOUNDARY_MIN, FIRST_FLUSH_HARD
    else:
        soft_min, hard_cap = SUBSEQUENT_FLUSH_BOUNDARY_MIN, SUBSEQUENT_FLUSH_HARD

    size = len(buffer)
    if size >= hard_cap:
        return True

    # Only the final character is inspected for a boundary.
    tail = buffer[-1]
    ends_sentence = tail in SENTENCE_BOUNDARIES and size >= soft_min
    ends_late_clause = tail in CLAUSE_BOUNDARIES and size >= hard_cap * 0.8
    return ends_sentence or ends_late_clause
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass
|
| 61 |
+
class _AudioSlot:
|
| 62 |
+
"""Holds synthesised audio for one TTS chunk. Delivered in slot order."""
|
| 63 |
+
index: int
|
| 64 |
+
ready: asyncio.Event = field(default_factory=asyncio.Event)
|
| 65 |
+
chunks: list[bytes] = field(default_factory=list)
|
| 66 |
+
error: bool = False
|
| 67 |
+
|
| 68 |
+
|
| 69 |
class ParallelTTSStreamer:
    """
    Collects LLM tokens → prosodic sentence chunks → parallel edge-tts
    synthesis → slot-ordered audio delivery.

    Usage
    ─────
        streamer = ParallelTTSStreamer()

        # producer side
        await streamer.add_token(token)
        await streamer.flush()

        # consumer side
        async for audio_bytes in streamer.stream_audio():
            await ws.send_bytes(audio_bytes)

        # abort everything that is still in flight
        await streamer.cancel()
    """

    def __init__(self, voice: str = VOICE) -> None:
        """Create an idle streamer that synthesises with *voice*."""
        self.voice = voice
        self.buffer = ""
        self._cancelled = False
        # True until the first chunk has been scheduled; selects the
        # smaller first-chunk thresholds in _should_flush().
        self._first_chunk = True
        self._slot_index = 0
        # Slots in creation order; stream_audio() walks this list by index.
        self._slots: list[_AudioSlot] = []
        self._slots_lock = asyncio.Lock()
        # Synthesis tasks that have not finished yet.
        self._tasks: list[asyncio.Task] = []
        # Set when no further slots will be produced (flush or cancel).
        self._done_event = asyncio.Event()

    async def add_token(self, token: str) -> None:
        """Append *token* to the buffer and schedule a chunk when it is ready.

        Empty tokens, and tokens arriving after cancel(), are ignored.
        """
        if not token or self._cancelled:
            return

        self.buffer += token

        if _should_flush(self.buffer, self._first_chunk):
            self._first_chunk = False
            await self._schedule_chunk()

    async def _schedule_chunk(self) -> None:
        """Turn the current buffer into a slot plus a background synthesis task.

        The buffer is always emptied. Text that is shorter than MIN_CHARS
        after cleaning is dropped without creating a slot.
        """
        if self._cancelled:
            self.buffer = ""
            return

        raw = self.buffer.strip()
        self.buffer = ""

        text = _clean_for_tts(raw)
        if len(text) < MIN_CHARS:
            return

        # The slot is registered before synthesis starts so that delivery
        # order follows text order, not synthesis completion order.
        async with self._slots_lock:
            slot = _AudioSlot(index=self._slot_index)
            self._slot_index += 1
            self._slots.append(slot)

        task = asyncio.create_task(self._synthesise(text, slot))
        self._tasks.append(task)
        task.add_done_callback(
            lambda t: self._tasks.remove(t) if t in self._tasks else None
        )

    async def _synthesise(self, text: str, slot: _AudioSlot) -> None:
        """Stream edge-tts audio for *text* into *slot*.

        ``slot.ready`` is always set on exit. ``slot.error`` is set when
        the synthesis was cancelled or raised.
        """
        if self._cancelled:
            slot.error = True
            slot.ready.set()
            return

        try:
            communicate = edge_tts.Communicate(text, self.voice)
            async for chunk in communicate.stream():
                if self._cancelled:
                    slot.error = True
                    return
                if chunk["type"] == "audio":
                    slot.chunks.append(chunk["data"])
        except asyncio.CancelledError:
            slot.error = True
            # Re-raise so the task is recorded as cancelled rather than
            # as having completed normally.
            raise
        except Exception as exc:
            print(f"[TTS] edge-tts error for '{text[:50]}': {exc}")
            slot.error = True
        finally:
            slot.ready.set()

    async def flush(self) -> None:
        """Schedule any remaining buffered text and wait for all synthesis.

        The done event is set even if this coroutine is cancelled or
        raises, so stream_audio() cannot poll forever waiting for it.
        """
        try:
            if self.buffer.strip():
                await self._schedule_chunk()

            if self._tasks:
                await asyncio.gather(*self._tasks, return_exceptions=True)
        finally:
            self._done_event.set()

    async def cancel(self) -> None:
        """
        Immediately abort all in-flight synthesis tasks.
        Marks all pending slots as errored so stream_audio() exits promptly.
        Idempotent.
        """
        self._cancelled = True

        for task in list(self._tasks):
            task.cancel()
        self._tasks.clear()

        async with self._slots_lock:
            for slot in self._slots:
                if not slot.ready.is_set():
                    slot.error = True
                    slot.ready.set()

        self._done_event.set()

    async def stream_audio(self):
        """
        Yields ordered audio bytes. Slots are consumed in creation order;
        each slot is awaited individually so synthesis of slot N+1 can
        proceed in parallel while the consumer is yielding slot N's bytes.

        Errored slots are skipped. Once cancel() has been called nothing
        further is yielded, not even the remaining chunks of the slot
        currently being delivered.
        """
        delivered = 0

        while not self._cancelled:
            async with self._slots_lock:
                slot = self._slots[delivered] if delivered < len(self._slots) else None

            if slot is None:
                # No slot yet: stop if the producer has finished,
                # otherwise poll again shortly.
                if self._done_event.is_set():
                    break
                await asyncio.sleep(0.005)
                continue

            await slot.ready.wait()

            if not slot.error:
                for audio_bytes in slot.chunks:
                    # cancel() may run while the consumer is suspended
                    # at the yield below.
                    if self._cancelled:
                        return
                    yield audio_bytes

            delivered += 1

    def reset(self) -> None:
        """Return the streamer to its initial state for a new utterance.

        Synthesis tasks still running are cancelled, and slots not yet
        ready are released as errored, so nothing from the previous
        utterance keeps running or leaves a consumer waiting.
        """
        for task in list(self._tasks):
            task.cancel()
        for slot in self._slots:
            if not slot.ready.is_set():
                slot.error = True
                slot.ready.set()

        self._cancelled = False
        self._first_chunk = True
        self.buffer = ""
        self._slot_index = 0
        self._slots.clear()
        self._tasks.clear()
        self._done_event.clear()
|