rakib72642 committed on
Commit
4d2289b
·
1 Parent(s): b70a952

fluent communication part :: rakib

Browse files
Files changed (6) hide show
  1. .env +2 -0
  2. app.py +8 -27
  3. core/backend.py +272 -251
  4. frontend/index.html +1 -0
  5. frontend/script.js +97 -102
  6. services/streaming.py +178 -113
.env CHANGED
@@ -5,6 +5,8 @@ LANGCHAIN_ENDPOINT='https://api.smith.langchain.com'
5
  LANGCHAIN_API_KEY='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'
6
  LANGCHAIN_PROJECT='Default'
7
 
 
 
8
  # TWILIO_ACCOUNT_SID="ACfafc0d2d007bdf14b21bb3e14a7a7b31"
9
  # TWILIO_AUTH_TOKEN="ed15fa98748c8c3d3d02cb54e431a187"
10
  # TWILIO_PHONE_NUMBER="+14343375085"
 
5
  LANGCHAIN_API_KEY='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'
6
  LANGCHAIN_PROJECT='Default'
7
 
8
+ GOOGLE_API_KEY="AIzaSyA9sqz4YKQHKXR9TU1imw0DPOghzHOMiBo"
9
+
10
  # TWILIO_ACCOUNT_SID="ACfafc0d2d007bdf14b21bb3e14a7a7b31"
11
  # TWILIO_AUTH_TOKEN="ed15fa98748c8c3d3d02cb54e431a187"
12
  # TWILIO_PHONE_NUMBER="+14343375085"
app.py CHANGED
@@ -1,19 +1,3 @@
1
- """
2
- app.py — FastAPI entry point
3
-
4
- Fixes applied
5
- ─────────────
6
- 1. STT is now fully async (stt.transcribe is a coroutine) — no more
7
- asyncio.to_thread wrapper needed in the WS handler.
8
- 2. BARGE-IN: when the client sends a new audio blob while TTS is still
9
- playing, the running tts_streamer is cancelled before starting a new
10
- turn. The client enforces isProcessing so this should be rare, but
11
- the server now handles it gracefully.
12
- 3. Per-session cancel token stored in `_active_streamer` so any new
13
- utterance from the same WS cleanly aborts the previous one.
14
- 4. All other logic (ping/pong, safe send helpers, chat WS) is unchanged.
15
- """
16
-
17
  import asyncio
18
  import json
19
  import os
@@ -55,7 +39,6 @@ async def root():
55
  return HTMLResponse("<h2>index.html not found</h2>", status_code=404)
56
 
57
 
58
- # ── Helpers ────────────────────────────────────────────────────────────────────
59
  def _ws_open(ws: WebSocket) -> bool:
60
  return ws.client_state == WebSocketState.CONNECTED
61
 
@@ -80,7 +63,6 @@ async def _safe_bytes(ws: WebSocket, data: bytes) -> bool:
80
  return False
81
 
82
 
83
- # ── Text chat WebSocket ────────────────────────────────────────────────────────
84
  @app.websocket("/ws/chat")
85
  async def ws_chat(ws: WebSocket):
86
  await ws.accept()
@@ -118,7 +100,6 @@ async def ws_chat(ws: WebSocket):
118
  print(f"[CHAT] WS error: {exc}")
119
 
120
 
121
- # ── Voice WebSocket ────────────────────────────────────────────────────────────
122
  @app.websocket("/ws/voice")
123
  async def ws_voice(ws: WebSocket):
124
  await ws.accept()
@@ -126,7 +107,7 @@ async def ws_voice(ws: WebSocket):
126
 
127
  stt = STTProcessor()
128
  user_id = "voice_user"
129
- _active_streamer: ParallelTTSStreamer | None = None # barge-in handle
130
 
131
  try:
132
  while True:
@@ -146,18 +127,18 @@ async def ws_voice(ws: WebSocket):
146
  print(f"[VOICE] Receive error: {exc}")
147
  break
148
 
149
- # ── Audio blob from client VAD ──────────────────────────────────
150
  if "bytes" in data and data["bytes"]:
151
  audio_bytes = data["bytes"]
152
  print(f"[VOICE] Received utterance: {len(audio_bytes):,} bytes")
153
 
154
- # ── Barge-in: cancel any running TTS turn ───────────────────
155
  if _active_streamer is not None:
156
  print("[VOICE] Barge-in — cancelling previous TTS.")
157
  await _active_streamer.cancel()
158
  _active_streamer = None
159
 
160
- # 1. STT — now a native coroutine (GPU semaphore inside)
161
  transcript = await stt.transcribe(audio_bytes)
162
 
163
  if not transcript:
@@ -172,7 +153,7 @@ async def ws_voice(ws: WebSocket):
172
  if not await _safe_text(ws, {"type": "stt", "text": transcript}):
173
  break
174
 
175
- # 2. AI + TTS pipeline
176
  tts_streamer = ParallelTTSStreamer()
177
  _active_streamer = tts_streamer
178
 
@@ -198,17 +179,17 @@ async def ws_voice(ws: WebSocket):
198
  await asyncio.gather(run_ai_and_tts(), stream_tts_audio())
199
  _active_streamer = None
200
 
201
- # Signal end-of-turn → client resumes VAD
202
  await _safe_text(ws, {"type": "end"})
203
 
204
- # ── Control messages ────────────────────────────────────────────
205
  elif "text" in data and data["text"]:
206
  try:
207
  msg = json.loads(data["text"])
208
  if msg.get("type") == "ping":
209
  await _safe_text(ws, {"type": "pong"})
210
 
211
- # Client can send {"type":"cancel"} to abort TTS mid-turn
212
  elif msg.get("type") == "cancel":
213
  if _active_streamer is not None:
214
  print("[VOICE] Client cancel signal received.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
  import json
3
  import os
 
39
  return HTMLResponse("<h2>index.html not found</h2>", status_code=404)
40
 
41
 
 
42
  def _ws_open(ws: WebSocket) -> bool:
43
  return ws.client_state == WebSocketState.CONNECTED
44
 
 
63
  return False
64
 
65
 
 
66
  @app.websocket("/ws/chat")
67
  async def ws_chat(ws: WebSocket):
68
  await ws.accept()
 
100
  print(f"[CHAT] WS error: {exc}")
101
 
102
 
 
103
  @app.websocket("/ws/voice")
104
  async def ws_voice(ws: WebSocket):
105
  await ws.accept()
 
107
 
108
  stt = STTProcessor()
109
  user_id = "voice_user"
110
+ _active_streamer: ParallelTTSStreamer | None = None
111
 
112
  try:
113
  while True:
 
127
  print(f"[VOICE] Receive error: {exc}")
128
  break
129
 
130
+
131
  if "bytes" in data and data["bytes"]:
132
  audio_bytes = data["bytes"]
133
  print(f"[VOICE] Received utterance: {len(audio_bytes):,} bytes")
134
 
135
+
136
  if _active_streamer is not None:
137
  print("[VOICE] Barge-in — cancelling previous TTS.")
138
  await _active_streamer.cancel()
139
  _active_streamer = None
140
 
141
+
142
  transcript = await stt.transcribe(audio_bytes)
143
 
144
  if not transcript:
 
153
  if not await _safe_text(ws, {"type": "stt", "text": transcript}):
154
  break
155
 
156
+
157
  tts_streamer = ParallelTTSStreamer()
158
  _active_streamer = tts_streamer
159
 
 
179
  await asyncio.gather(run_ai_and_tts(), stream_tts_audio())
180
  _active_streamer = None
181
 
182
+
183
  await _safe_text(ws, {"type": "end"})
184
 
185
+
186
  elif "text" in data and data["text"]:
187
  try:
188
  msg = json.loads(data["text"])
189
  if msg.get("type") == "ping":
190
  await _safe_text(ws, {"type": "pong"})
191
 
192
+
193
  elif msg.get("type") == "cancel":
194
  if _active_streamer is not None:
195
  print("[VOICE] Client cancel signal received.")
core/backend.py CHANGED
@@ -1,36 +1,43 @@
1
- from langgraph.graph import StateGraph, START, END
2
- from typing import TypedDict, Annotated
3
- from langchain_core.messages import BaseMessage
4
- from langgraph.graph.message import add_messages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
6
- from langchain_ollama import ChatOllama
 
7
  from langgraph.prebuilt import ToolNode, tools_condition
8
- from langchain_community.tools import DuckDuckGoSearchRun
9
- from langchain_core.tools import tool
10
- from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, RemoveMessage, SystemMessage
11
- import aiosqlite, uuid, os, httpx, asyncio
12
  from twilio.rest import Client
13
- from dotenv import load_dotenv
14
- import json, pytz
15
- from datetime import datetime
16
 
17
- ######################### STATE #########################
 
 
 
18
  class ChatState(TypedDict):
19
- messages: Annotated[list[BaseMessage], add_messages]
20
  summary: str
21
 
22
- ######################### TOOLS #########################
23
- # After imports, before STATE class
24
- def get_db_path():
 
 
25
  return os.path.join(os.path.dirname(__file__), "daa.db")
26
 
27
- def send_sms(to_number: str, message: str):
28
- client = Client(os.getenv("TWILIO_ACCOUNT_SID"), os.getenv("TWILIO_AUTH_TOKEN"))
29
- client.messages.create(
30
- body=message,
31
- from_=os.getenv("TWILIO_PHONE_NUMBER"),
32
- to=to_number
33
- )
34
 
35
  def format_bd_number(num: str) -> str:
36
  num = num.strip().replace(" ", "")
@@ -38,36 +45,50 @@ def format_bd_number(num: str) -> str:
38
  return "+88" + num
39
  if num.startswith("8801"):
40
  return "+" + num
41
- return num # already formatted or unknown
 
 
 
 
 
 
 
 
 
 
42
 
 
 
 
43
  @tool
44
  def get_bd_time() -> str:
45
- """
46
- Get current Bangladesh time (Asia/Dhaka) with weekday name
47
- """
48
- tz = pytz.timezone("Asia/Dhaka")
49
  now = datetime.now(tz)
50
  return now.strftime("%Y-%m-%d %H:%M:%S (%A, Bangladesh Time)")
51
 
 
52
  @tool
53
- async def search_doctor(name: str = "", category: str = "", visiting_days: str = "") -> str:
 
 
 
 
54
  """
55
- Search doctors by name, category, or visiting_days from SQLite database.
56
- Any combination of filters is supported (OR logic for each field).
57
  """
58
- db_path = get_db_path()
59
- query = "SELECT * FROM doctors WHERE 1=1"
60
- params = []
61
- conditions = []
62
 
63
  if name:
64
  conditions.append("LOWER(doctor_name) LIKE ?")
65
  params.append(f"%{name.lower()}%")
66
-
67
  if category:
68
  conditions.append("LOWER(category) LIKE ?")
69
  params.append(f"%{category.lower()}%")
70
-
71
  if visiting_days:
72
  conditions.append("LOWER(visiting_days) LIKE ?")
73
  params.append(f"%{visiting_days.lower()}%")
@@ -78,119 +99,89 @@ async def search_doctor(name: str = "", category: str = "", visiting_days: str =
78
  async with aiosqlite.connect(db_path) as db:
79
  db.row_factory = aiosqlite.Row
80
  cursor = await db.execute(query, params)
81
- rows = await cursor.fetchall()
82
 
83
  if not rows:
84
- return json.dumps({
85
- "success": False,
86
- "message": "No doctors found matching your search.",
87
- "data": []
88
- })
89
 
90
- return json.dumps({
91
- "success": True,
92
- "count": len(rows),
93
- "data": [dict(r) for r in rows]
94
- })
95
 
96
  @tool
97
  async def search_appointment_by_phone(patient_num: str) -> str:
98
- """
99
- Search all appointments using patient phone number.
100
- """
101
- db_path = get_db_path()
102
  patient_num = format_bd_number(patient_num)
103
 
104
  async with aiosqlite.connect(db_path) as db:
105
  db.row_factory = aiosqlite.Row
106
-
107
- cursor = await db.execute("""
108
- SELECT * FROM patients
109
- WHERE patient_num = ?
110
- ORDER BY visiting_date ASC
111
- """, (patient_num,))
112
-
113
  rows = await cursor.fetchall()
114
 
115
  if not rows:
116
  return json.dumps({
117
  "success": False,
118
  "message": "No appointments found for this phone number.",
119
- "data": []
120
  })
 
121
 
122
- return json.dumps({
123
- "success": True,
124
- "count": len(rows),
125
- "data": [dict(r) for r in rows]
126
- })
127
 
128
  @tool
129
- async def book_appointment(doctor_id: int, patient_name: str, patient_age: str, patient_num: str, visiting_date: str) -> str:
 
 
 
 
 
 
130
  """
131
  Book a doctor appointment and save it to the patients table.
132
 
133
  Args:
134
- doctor_id: Doctor's ID from search_doctor results.
135
- patient_name: Full name of the patient.
136
- patient_age: Age of the patient (e.g. "32").
137
- patient_num: Contact phone number of the patient.
138
  visiting_date: Date of visit in YYYY-MM-DD format (e.g. 2025-06-15).
139
-
140
- Returns a booking confirmation with the new record ID.
141
  """
142
- db_path = get_db_path()
 
143
 
144
  async with aiosqlite.connect(db_path) as db:
145
  db.row_factory = aiosqlite.Row
146
 
147
- patient_num = format_bd_number(patient_num)
148
-
149
- # Verify doctor exists
150
  cursor = await db.execute("SELECT * FROM doctors WHERE id = ?", (doctor_id,))
151
  doctor = await cursor.fetchone()
152
  if not doctor:
153
  return f"No doctor found with ID {doctor_id}. Please search for a doctor first."
154
 
155
- doctor_data = dict(doctor)
156
- doctor_name = doctor_data.get("doctor_name", "Unknown")
157
- doctor_category = doctor_data.get("doctor_category", "Unknown")
158
 
159
- # Check for conflicting booking (same doctor + same date)
160
  cursor = await db.execute(
161
  """SELECT id FROM patients
162
- WHERE doctor_name = ? AND visiting_date = ? AND patient_num = ?""",
163
  (doctor_name, visiting_date, patient_num),
164
  )
165
- conflict = await cursor.fetchone()
166
- if conflict:
167
  return (
168
  f"A booking for {patient_name} with Dr. {doctor_name} "
169
  f"on {visiting_date} already exists."
170
  )
171
 
172
- # Insert into patients table
173
- cursor = await db.execute(
174
- """INSERT INTO patients (doctor_name, doctor_category, patient_name, patient_age, patient_num, visiting_date)
175
- VALUES (?, ?, ?, ?, ?, ?)""",
176
  (doctor_name, doctor_category, patient_name, patient_age, patient_num, visiting_date),
177
  )
178
  await db.commit()
179
 
180
- # Send SMS confirmation
181
- sms_message = (
182
- f"✅ Appointment Confirmed!\n"
183
- f"Doctor : {doctor_name}\n"
184
- f"Patient : {patient_name}\n"
185
- f"Visit Date : {visiting_date}\n"
186
- f"Please arrive 10 minutes early."
187
- )
188
- # try:
189
- # send_sms(to_number=patient_num, message=sms_message)
190
- # sms_status = "📱 SMS confirmation sent."
191
- # except Exception as e:
192
- # sms_status = f"⚠️ SMS failed: {str(e)}"
193
-
194
  return (
195
  f"✅ Appointment Booked!\n"
196
  f"━━━━━━━━━━━━━━━━━━━━━━\n"
@@ -201,182 +192,203 @@ async def book_appointment(doctor_id: int, patient_name: str, patient_age: str,
201
  f"Contact : {patient_num}\n"
202
  f"━━━━━━━━━━━━━━━━━━━━━━\n"
203
  f"Please arrive 10 minutes early."
204
- # f"{sms_status}"
205
  )
206
 
 
 
207
  async def delete_appointment(patient_num: str, doctor_name: str) -> str:
208
- """
209
- Delete an appointment using patient phone number and doctor name.
210
- """
211
- db_path = get_db_path()
212
- # normalize phone number
213
  patient_num = format_bd_number(patient_num)
214
 
215
  async with aiosqlite.connect(db_path) as db:
216
  db.row_factory = aiosqlite.Row
217
 
218
- # check if appointment exists first
219
- cursor = await db.execute("""
220
- SELECT * FROM patients
221
- WHERE patient_num = ?
222
- AND LOWER(doctor_name) = LOWER(?)
223
- """, (patient_num, doctor_name))
224
-
225
- row = await cursor.fetchone()
226
- if not row:
227
- return json.dumps({
228
- "success": False,
229
- "message": "No matching appointment found to delete."
230
- })
231
-
232
- # delete appointment
233
- await db.execute("""
234
- DELETE FROM patients
235
- WHERE patient_num = ?
236
- AND LOWER(doctor_name) = LOWER(?)
237
- """, (patient_num, doctor_name))
238
 
 
 
 
 
 
239
  await db.commit()
240
 
241
  return json.dumps({
242
  "success": True,
243
- "message": f"Appointment with Dr. {doctor_name} deleted successfully."
244
  })
245
 
246
- ######################### MAIN AGENT CLASS #########################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  class AIBackend:
248
- def __init__(self):
 
249
  load_dotenv()
250
- os.environ["LANGCHAIN_PROJECT"] = "Doctor Appointment Automation"
251
- self.llm = ChatOllama(model="gemma4:e4b", streaming=True) # qwen2.5:3b, gemma4:e4b
252
- self.tools = [search_doctor, book_appointment, get_bd_time, search_appointment_by_phone, delete_appointment]
253
- self.tool_node = ToolNode(self.tools)
 
 
 
 
 
 
 
 
 
 
 
254
  self.llm_with_tools = self.llm.bind_tools(self.tools)
255
 
256
- async def async_setup(self):
257
- db_path = os.path.join(os.path.dirname(__file__), "daa.db")
258
- self.conn = await aiosqlite.connect(db_path)
 
259
  self.checkpointer = AsyncSqliteSaver(self.conn)
260
- await self._create_user_table()
261
- self.graph = self._build_graph()
262
  self.summary_graph = self._build_summary_graph()
263
 
264
- async def _create_user_table(self):
265
  await self.conn.execute("""
266
  CREATE TABLE IF NOT EXISTS userid_threadid (
267
- userId TEXT UNIQUE NOT NULL,
268
  threadId TEXT UNIQUE NOT NULL
269
  )
270
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  await self.conn.commit()
272
 
273
- ######################### SUMMARIZE NODE #########################
274
  async def summarize_conversation(self, state: ChatState):
275
- existing_summary = state.get("summary", "")
276
  messages = state["messages"]
277
- prompt = (
278
- f"""
279
- You are maintaining a long-term conversation memory for a chatbot.
280
-
281
- Existing summary:
282
- {existing_summary}
283
-
284
- Update and extend the summary using ONLY the new conversation messages above.
285
-
286
- Instructions:
287
- - Preserve important existing context.
288
- - Add new facts, decisions, preferences, goals, issues, and ongoing tasks.
289
- - Keep technical details concise but meaningful.
290
- - Track unresolved problems or follow-up actions.
291
- - Avoid repetition and remove outdated or redundant information when appropriate.
292
- - Maintain chronological consistency.
293
- - Write the summary in clear bullet points.
294
- - Focus on information useful for future conversations and contextual continuity.
295
- - Do NOT include casual greetings or temporary small talk unless important.
296
- - Keep the summary compact but information-dense.
297
- """
298
- if existing_summary
299
- else
300
- """
301
- You are creating a long-term conversation memory summary for a chatbot.
302
-
303
- Summarize the conversation above.
304
-
305
- Instructions:
306
- - Capture important user information, goals, preferences, projects, and decisions.
307
- - Include technical issues, debugging progress, and solutions discussed.
308
- - Track ongoing tasks or unresolved questions.
309
- - Ignore casual greetings and low-value chatter.
310
- - Write concise, structured bullet points.
311
- - Keep the summary compact but highly informative for future context retention.
312
- """
313
- )
314
- messages_for_summary = messages + [HumanMessage(content=prompt)]
315
- response = await self.llm.ainvoke(messages_for_summary)
316
  return {
317
  "summary": response.content,
318
  "messages": [RemoveMessage(id=m.id) for m in messages[:-2]],
319
  }
320
 
321
- async def should_summarize(self, state: ChatState):
322
- if len(state["messages"]) > 10:
323
- return "summarize_node"
324
- return "chat_node"
325
 
326
- ######################### CHAT NODE #########################
327
  async def chat_node(self, state: ChatState):
328
- summary = state.get("summary", "")
 
 
 
 
 
 
 
 
 
329
  messages = state["messages"]
330
 
331
- print('#'*50)
332
  print(">>>>>>>>>> CHAT NODE START <<<<<<<<<<")
333
- if summary:
334
- print(f"[SUMMARY]:\n{summary}\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  else:
336
- print("[NO SUMMARY YET]\n")
 
 
 
337
 
338
- print('$'*50)
339
- print("[MESSAGES]:")
340
- for m in messages:
341
- role = m.__class__.__name__
342
- print(f" [{role}]: {m.content[:200]}")
343
- print('$'*50,'\n')
344
-
345
- if summary:
346
- summary_message = SystemMessage(
347
- content=(
348
- "You are a Bangla voice assistant. You are provided with a condensed memory of previous conversations.\n\n"
349
- f"Conversation Memory:\n{summary}\n\n"
350
- "Instructions:\n"
351
- "- Always respond in Bangla (বাংলা)"
352
- "- Keep sentences short for speech"
353
- "- No English unless necessary"
354
- "- Use this memory as long-term conversational context.\n"
355
- "- Maintain continuity with the user's previous discussions, projects, goals, and preferences.\n"
356
- "- Prioritize recent and relevant information when generating responses.\n"
357
- "- Do not repeat the summary unless necessary.\n"
358
- "- If new information conflicts with old memory, prefer the latest context.\n"
359
- "- Use the memory naturally to improve personalization, reasoning, and follow-up responses.\n"
360
- "- Treat unresolved issues, active projects, and pending tasks as ongoing unless stated otherwise."
361
- )
362
- )
363
- messages = [summary_message] + messages
364
- response = await self.llm_with_tools.ainvoke(messages)
365
- print(f"Final [{response.__class__.__name__}]: {response.content[:200]}")
366
  print(">>>>>>>>>> CHAT NODE END <<<<<<<<<<")
367
- print('#'*50)
368
  return {"messages": [response]}
369
 
370
- ######################### GRAPH #########################
371
  def _build_graph(self):
372
  g = StateGraph(ChatState)
373
  g.add_node("chat_node", self.chat_node)
374
- g.add_node("tools", self.tool_node)
375
-
376
  g.add_edge(START, "chat_node")
377
  g.add_conditional_edges("chat_node", tools_condition)
378
  g.add_edge("tools", "chat_node")
379
-
380
  return g.compile(checkpointer=self.checkpointer)
381
 
382
  def _build_summary_graph(self):
@@ -386,40 +398,49 @@ class AIBackend:
386
  g.add_edge("summarize_node", END)
387
  return g.compile(checkpointer=self.checkpointer)
388
 
389
- ######################### STREAMING #########################
390
  async def ai_only_stream(self, initial_state: dict, config: dict):
391
- async for message_chunk, metadata in self.graph.astream(initial_state, config=config, stream_mode="messages"):
392
- if isinstance(message_chunk, AIMessage) and message_chunk.content:
393
- yield message_chunk.content
394
-
395
- # Auto Summarization Execute
396
- current_state = await self.graph.aget_state(config)
397
- if len(current_state.values.get("messages", [])) > 10:
 
 
 
 
 
 
 
 
 
398
  asyncio.create_task(
399
- self.summary_graph.ainvoke(current_state.values, config=config)
400
  )
401
- print('@'*20,'Summarization Execute','@'*20)
402
 
403
- ######################### THREAD ID #########################
404
  @staticmethod
405
  def generate_thread_id() -> str:
406
  return str(uuid.uuid4())
407
 
408
- ######################### RETRIEVE ALL THREADS #########################
409
- async def retrieve_all_threads(self):
410
- all_threads = set()
411
- async for checkpoint in self.checkpointer.alist(None):
412
- all_threads.add(checkpoint.config["configurable"]["thread_id"])
413
- return list(all_threads)
414
 
415
- ######################### MAIN ENTRY POINT #########################
416
  async def main(self, user_id: str, user_query: str):
 
417
  async with self.conn.execute(
418
- "SELECT userId, threadId FROM userid_threadid WHERE userId = ?", (user_id,)
419
  ) as cursor:
420
- result = await cursor.fetchone()
421
 
422
- if result is None:
423
  thread_id = user_id + self.generate_thread_id()
424
  await self.conn.execute(
425
  "INSERT INTO userid_threadid (userId, threadId) VALUES (?, ?)",
@@ -427,12 +448,12 @@ class AIBackend:
427
  )
428
  await self.conn.commit()
429
  else:
430
- thread_id = result[1]
431
 
432
  initial_state = {"messages": [HumanMessage(content=user_query)]}
433
  config = {
434
  "configurable": {"thread_id": thread_id},
435
- "metadata": {"thread_id": thread_id},
436
- "run_name": "chat_turn",
437
  }
438
  return self.ai_only_stream(initial_state, config)
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import os
6
+ import uuid
7
+
8
+ import aiosqlite
9
+ import pytz
10
+ from datetime import datetime
11
+ from dotenv import load_dotenv
12
+
13
+ from langchain_core.messages import (
14
+ AIMessage, AIMessageChunk, HumanMessage, RemoveMessage,
15
+ SystemMessage, ToolMessage,
16
+ )
17
+ from langchain_core.tools import tool
18
+ from langchain_google_genai import ChatGoogleGenerativeAI
19
  from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
20
+ from langgraph.graph import END, START, StateGraph
21
+ from langgraph.graph.message import add_messages
22
  from langgraph.prebuilt import ToolNode, tools_condition
 
 
 
 
23
  from twilio.rest import Client
24
+ from typing import Annotated, TypedDict
 
 
25
 
26
+
27
+ # ═══════════════════════════════════════════════════════════════════════════════
28
+ # STATE
29
+ # ═══════════════════════════════════════════════════════════════════════════════
30
  class ChatState(TypedDict):
31
+ messages: Annotated[list, add_messages]
32
  summary: str
33
 
34
+
35
+ # ═══════════════════════════════════════════════════════════════════════════════
36
+ # HELPERS
37
+ # ═══════════════════════════════════════════════════════════════════════════════
38
+ def get_db_path() -> str:
39
  return os.path.join(os.path.dirname(__file__), "daa.db")
40
 
 
 
 
 
 
 
 
41
 
42
  def format_bd_number(num: str) -> str:
43
  num = num.strip().replace(" ", "")
 
45
  return "+88" + num
46
  if num.startswith("8801"):
47
  return "+" + num
48
+ return num
49
+
50
+
51
+ def send_sms(to_number: str, message: str) -> None:
52
+ client = Client(os.getenv("TWILIO_ACCOUNT_SID"), os.getenv("TWILIO_AUTH_TOKEN"))
53
+ client.messages.create(
54
+ body=message,
55
+ from_=os.getenv("TWILIO_PHONE_NUMBER"),
56
+ to=to_number,
57
+ )
58
+
59
 
60
+ # ═══════════════════════════════════════════════════════════════════════════════
61
+ # TOOLS
62
+ # ═══════════════════════════════════════════════════════════════════════════════
63
  @tool
64
  def get_bd_time() -> str:
65
+ """Get current Bangladesh time (Asia/Dhaka) with weekday name."""
66
+ tz = pytz.timezone("Asia/Dhaka")
 
 
67
  now = datetime.now(tz)
68
  return now.strftime("%Y-%m-%d %H:%M:%S (%A, Bangladesh Time)")
69
 
70
+
71
  @tool
72
+ async def search_doctor(
73
+ name: str = "",
74
+ category: str = "",
75
+ visiting_days: str = "",
76
+ ) -> str:
77
  """
78
+ Search doctors by name, category, or visiting_days from the SQLite database.
79
+ Any combination of filters is supported (OR logic across fields).
80
  """
81
+ db_path = get_db_path()
82
+ query = "SELECT * FROM doctors WHERE 1=1"
83
+ params: list = []
84
+ conditions: list[str] = []
85
 
86
  if name:
87
  conditions.append("LOWER(doctor_name) LIKE ?")
88
  params.append(f"%{name.lower()}%")
 
89
  if category:
90
  conditions.append("LOWER(category) LIKE ?")
91
  params.append(f"%{category.lower()}%")
 
92
  if visiting_days:
93
  conditions.append("LOWER(visiting_days) LIKE ?")
94
  params.append(f"%{visiting_days.lower()}%")
 
99
  async with aiosqlite.connect(db_path) as db:
100
  db.row_factory = aiosqlite.Row
101
  cursor = await db.execute(query, params)
102
+ rows = await cursor.fetchall()
103
 
104
  if not rows:
105
+ return json.dumps({"success": False, "message": "No doctors found.", "data": []})
106
+
107
+ return json.dumps({"success": True, "count": len(rows), "data": [dict(r) for r in rows]})
 
 
108
 
 
 
 
 
 
109
 
110
  @tool
111
  async def search_appointment_by_phone(patient_num: str) -> str:
112
+ """Search all appointments using the patient's phone number."""
113
+ db_path = get_db_path()
 
 
114
  patient_num = format_bd_number(patient_num)
115
 
116
  async with aiosqlite.connect(db_path) as db:
117
  db.row_factory = aiosqlite.Row
118
+ cursor = await db.execute(
119
+ "SELECT * FROM patients WHERE patient_num = ? ORDER BY visiting_date ASC",
120
+ (patient_num,),
121
+ )
 
 
 
122
  rows = await cursor.fetchall()
123
 
124
  if not rows:
125
  return json.dumps({
126
  "success": False,
127
  "message": "No appointments found for this phone number.",
128
+ "data": [],
129
  })
130
+ return json.dumps({"success": True, "count": len(rows), "data": [dict(r) for r in rows]})
131
 
 
 
 
 
 
132
 
133
  @tool
134
+ async def book_appointment(
135
+ doctor_id: int,
136
+ patient_name: str,
137
+ patient_age: str,
138
+ patient_num: str,
139
+ visiting_date: str,
140
+ ) -> str:
141
  """
142
  Book a doctor appointment and save it to the patients table.
143
 
144
  Args:
145
+ doctor_id: Doctor's ID from search_doctor results.
146
+ patient_name: Full name of the patient.
147
+ patient_age: Age of the patient (e.g. "32").
148
+ patient_num: Contact phone number of the patient.
149
  visiting_date: Date of visit in YYYY-MM-DD format (e.g. 2025-06-15).
 
 
150
  """
151
+ db_path = get_db_path()
152
+ patient_num = format_bd_number(patient_num)
153
 
154
  async with aiosqlite.connect(db_path) as db:
155
  db.row_factory = aiosqlite.Row
156
 
 
 
 
157
  cursor = await db.execute("SELECT * FROM doctors WHERE id = ?", (doctor_id,))
158
  doctor = await cursor.fetchone()
159
  if not doctor:
160
  return f"No doctor found with ID {doctor_id}. Please search for a doctor first."
161
 
162
+ doctor_data = dict(doctor)
163
+ doctor_name = doctor_data.get("doctor_name", "Unknown")
164
+ doctor_category = doctor_data.get("category", "Unknown")
165
 
 
166
  cursor = await db.execute(
167
  """SELECT id FROM patients
168
+ WHERE doctor_name = ? AND visiting_date = ? AND patient_num = ?""",
169
  (doctor_name, visiting_date, patient_num),
170
  )
171
+ if await cursor.fetchone():
 
172
  return (
173
  f"A booking for {patient_name} with Dr. {doctor_name} "
174
  f"on {visiting_date} already exists."
175
  )
176
 
177
+ await db.execute(
178
+ """INSERT INTO patients
179
+ (doctor_name, doctor_category, patient_name, patient_age, patient_num, visiting_date)
180
+ VALUES (?, ?, ?, ?, ?, ?)""",
181
  (doctor_name, doctor_category, patient_name, patient_age, patient_num, visiting_date),
182
  )
183
  await db.commit()
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  return (
186
  f"✅ Appointment Booked!\n"
187
  f"━━━━━━━━━━━━━━━━━━━━━━\n"
 
192
  f"Contact : {patient_num}\n"
193
  f"━━━━━━━━━━━━━━━━━━━━━━\n"
194
  f"Please arrive 10 minutes early."
 
195
  )
196
 
197
+
198
+ @tool
199
  async def delete_appointment(patient_num: str, doctor_name: str) -> str:
200
+ """Delete an appointment using the patient's phone number and doctor name."""
201
+ db_path = get_db_path()
 
 
 
202
  patient_num = format_bd_number(patient_num)
203
 
204
  async with aiosqlite.connect(db_path) as db:
205
  db.row_factory = aiosqlite.Row
206
 
207
+ cursor = await db.execute(
208
+ """SELECT * FROM patients
209
+ WHERE patient_num = ? AND LOWER(doctor_name) = LOWER(?)""",
210
+ (patient_num, doctor_name),
211
+ )
212
+ if not await cursor.fetchone():
213
+ return json.dumps({"success": False, "message": "No matching appointment found."})
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
+ await db.execute(
216
+ """DELETE FROM patients
217
+ WHERE patient_num = ? AND LOWER(doctor_name) = LOWER(?)""",
218
+ (patient_num, doctor_name),
219
+ )
220
  await db.commit()
221
 
222
  return json.dumps({
223
  "success": True,
224
+ "message": f"Appointment with Dr. {doctor_name} deleted successfully.",
225
  })
226
 
227
+
228
+ # ═══════════════════════════════════════════════════════════════════════════════
229
+ # SYSTEM PROMPT
230
+ # ═══════════════════════════════════════════════════════════════════════════════
231
+ BASE_SYSTEM = (
232
+ "You are a helpful Bangla voice assistant for a doctor appointment system.\n"
233
+ "Rules:\n"
234
+ "- Always respond in Bangla (বাংলা).\n"
235
+ "- Keep sentences short and natural for text-to-speech playback.\n"
236
+ "- Avoid markdown, bullet points, or long lists in voice responses.\n"
237
+ "- Use tools when needed to search doctors or manage appointments.\n"
238
+ "- Be polite, concise, and clear.\n"
239
+ "- Do not use English unless a proper noun requires it.\n"
240
+ )
241
+
242
+ SUMMARY_SYSTEM = (
243
+ BASE_SYSTEM
244
+ + "\nYou also have a condensed memory of previous conversations:\n\n"
245
+ "{summary}\n\n"
246
+ "Use this memory for continuity. Do not repeat it unless asked."
247
+ )
248
+
249
+
250
+ # ═══════════════════════════════════════════════════════════════════════════════
251
+ # AGENT
252
+ # ═══════════════════════════════════════════════════════════════════════════════
253
  class AIBackend:
254
+
255
+ def __init__(self) -> None:
256
  load_dotenv()
257
+ os.environ.setdefault("LANGCHAIN_PROJECT", "Doctor Appointment Automation")
258
+
259
+ self.llm = ChatGoogleGenerativeAI(
260
+ model="gemini-2.0-flash",
261
+ temperature=0.3,
262
+ )
263
+
264
+ self.tools = [
265
+ search_doctor,
266
+ book_appointment,
267
+ get_bd_time,
268
+ search_appointment_by_phone,
269
+ delete_appointment,
270
+ ]
271
+ self.tool_node = ToolNode(self.tools)
272
  self.llm_with_tools = self.llm.bind_tools(self.tools)
273
 
274
+ # ── Setup ──────────────────────────────────────────────────────────────────
275
async def async_setup(self) -> None:
    """Open the SQLite connection, ensure tables exist, and compile both graphs.

    Must be awaited once before ``main`` is used, because the graphs and the
    checkpointer are created here rather than in ``__init__``.
    """
    connection = await aiosqlite.connect(get_db_path())
    self.conn = connection
    self.checkpointer = AsyncSqliteSaver(connection)

    await self._create_tables()

    # Both graphs are compiled against the checkpointer created above.
    self.graph = self._build_graph()
    self.summary_graph = self._build_summary_graph()
282
 
283
async def _create_tables(self) -> None:
    """Create the application tables if they do not exist, then commit.

    Tables: ``userid_threadid`` (user to thread mapping), ``doctors`` and
    ``patients``. Every statement uses ``CREATE TABLE IF NOT EXISTS``.
    """
    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS userid_threadid (
            userId TEXT UNIQUE NOT NULL,
            threadId TEXT UNIQUE NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS doctors (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            doctor_name TEXT NOT NULL,
            category TEXT NOT NULL,
            visiting_days TEXT NOT NULL,
            chamber TEXT,
            fee TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS patients (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            doctor_name TEXT NOT NULL,
            doctor_category TEXT,
            patient_name TEXT NOT NULL,
            patient_age TEXT,
            patient_num TEXT NOT NULL,
            visiting_date TEXT NOT NULL
        )
        """,
    )
    for statement in ddl_statements:
        await self.conn.execute(statement)
    await self.conn.commit()
312
 
313
+ # ── Summarise node ─────────────────────────────────────────────────────────
314
async def summarize_conversation(self, state: ChatState):
    """Graph node: condense the message history into ``summary``.

    Asks the LLM for a new (or updated) summary and returns a state update
    that stores it and removes every message except the last two.
    """
    prior_summary = state.get("summary", "")
    history = state["messages"]

    if not prior_summary:
        instruction = (
            "Summarise this conversation. "
            "Capture goals, decisions, preferences, and unresolved questions. "
            "Be concise and use bullet points."
        )
    else:
        instruction = (
            f"Existing summary:\n{prior_summary}\n\n"
            "Update the summary with the new messages above. "
            "Keep it concise, bullet-pointed, and information-dense. "
            "Preserve unresolved issues and ongoing tasks."
        )

    # The instruction is appended after the history so it refers to
    # "the messages above".
    request = history + [HumanMessage(content=instruction)]
    response = await self.llm.ainvoke(request)

    # NOTE(review): keeping exactly the last two messages may separate a tool
    # result from the AI message that requested it; confirm this is acceptable.
    stale_messages = history[:-2]
    removals = [RemoveMessage(id=old.id) for old in stale_messages]
    return {"summary": response.content, "messages": removals}
337
 
338
+ async def should_summarize(self, state: ChatState) -> str:
339
+ return "summarize_node" if len(state["messages"]) > 10 else "chat_node"
 
 
340
 
341
+ # ── Chat node — streaming version ──────────────────────────────────────────
342
async def chat_node(self, state: ChatState):
    """Graph node: run one LLM turn and return the merged AI message.

    The model is consumed through ``astream()`` and the streamed chunks are
    added together into a single message, which is returned as the state
    update ``{"messages": [response]}``.
    The intent (per the original author) is that LangGraph's
    ``stream_mode="messages"`` can relay tokens while this node is running.
    """
    summary = state.get("summary", "")
    messages = state["messages"]

    # Debug trace of the inputs for this turn.
    banner = "#" * 50
    print(banner)
    print(">>>>>>>>>> CHAT NODE START <<<<<<<<<<")
    print(f"[SUMMARY]: {summary[:120] if summary else 'None'}")
    for m in messages:
        print(f" [{m.__class__.__name__}]: {str(m.content)[:160]}")
    print(banner)

    # Use the summary-aware system prompt only when a summary exists.
    if summary:
        sys_content = SUMMARY_SYSTEM.format(summary=summary)
    else:
        sys_content = BASE_SYSTEM
    prompt = [SystemMessage(content=sys_content), *messages]

    # Fold chunks together as they arrive; chunk addition merges content
    # and tool calls.
    response = None
    async for piece in self.llm_with_tools.astream(prompt):
        response = piece if response is None else response + piece

    # The model produced no chunks at all: fall back to an empty reply.
    if response is None:
        response = AIMessage(content="")

    print(f"[AI]: {str(response.content)[:200]}")
    print(">>>>>>>>>> CHAT NODE END <<<<<<<<<<")
    return {"messages": [response]}
383
 
384
+ # ── Graph ──────────────────────────────────────────────────────────────────
385
def _build_graph(self):
    """Compile the chat graph: START -> chat_node <-> tools.

    ``tools_condition`` decides after each ``chat_node`` run whether to go to
    the tool node; tool results always flow back into ``chat_node``.
    """
    builder = StateGraph(ChatState)

    nodes = (("chat_node", self.chat_node), ("tools", self.tool_node))
    for node_name, handler in nodes:
        builder.add_node(node_name, handler)

    builder.add_edge(START, "chat_node")
    builder.add_conditional_edges("chat_node", tools_condition)
    builder.add_edge("tools", "chat_node")

    return builder.compile(checkpointer=self.checkpointer)
393
 
394
  def _build_summary_graph(self):
 
398
  g.add_edge("summarize_node", END)
399
  return g.compile(checkpointer=self.checkpointer)
400
 
401
+ # ── Streaming ──────────────────────────────────────────────────────────────
402
async def ai_only_stream(self, initial_state: dict, config: dict):
    """Async generator that yields AI text content from one graph run.

    Streams the chat graph with ``stream_mode="messages"`` and yields the
    ``content`` of every non-empty ``AIMessage`` chunk. Once the stream is
    exhausted, the saved state is inspected and, if it holds more than 10
    messages, the summary graph is started as a background task.

    Args:
        initial_state: State update for this turn (the new human message).
        config: LangGraph run config carrying the ``thread_id``.

    Yields:
        The ``content`` of each AI message chunk, in arrival order.
    """
    async for chunk, _meta in self.graph.astream(
        initial_state, config=config, stream_mode="messages"
    ):
        if isinstance(chunk, AIMessage) and chunk.content:
            yield chunk.content

    # Auto-summarise in the background when the history grows long.
    current = await self.graph.aget_state(config)
    if len(current.values.get("messages", [])) > 10:
        # The event loop keeps only weak references to tasks, so a
        # fire-and-forget task can be garbage-collected before it finishes.
        # Keep a strong reference until the task completes.
        pending = getattr(self, "_background_tasks", None)
        if pending is None:
            pending = set()
            self._background_tasks = pending

        def _on_done(finished: asyncio.Task) -> None:
            pending.discard(finished)
            # Retrieve the exception so a failure is reported here instead of
            # as an "exception was never retrieved" warning at shutdown.
            if not finished.cancelled() and finished.exception() is not None:
                print(f"[SUMMARY] background summarisation failed: {finished.exception()!r}")

        task = asyncio.create_task(
            self.summary_graph.ainvoke(current.values, config=config)
        )
        pending.add(task)
        task.add_done_callback(_on_done)
        print("@" * 20, "Summarisation triggered", "@" * 20)
423
 
424
+ # ── Thread management ──────────────────────────────────────────────────────
425
  @staticmethod
426
  def generate_thread_id() -> str:
427
  return str(uuid.uuid4())
428
 
429
async def retrieve_all_threads(self) -> list[str]:
    """Return the distinct thread ids found in the checkpointer.

    Iterates ``self.checkpointer.alist(None)`` and de-duplicates the
    ``thread_id`` of each checkpoint; the order of the result is unspecified.
    """
    unique_ids = {
        checkpoint.config["configurable"]["thread_id"]
        async for checkpoint in self.checkpointer.alist(None)
    }
    return list(unique_ids)
 
434
 
435
+ # ── Public entry point ─────────────────────────────────────────────────────
436
  async def main(self, user_id: str, user_query: str):
437
+ """Return an async generator of AI text tokens."""
438
  async with self.conn.execute(
439
+ "SELECT threadId FROM userid_threadid WHERE userId = ?", (user_id,)
440
  ) as cursor:
441
+ row = await cursor.fetchone()
442
 
443
+ if row is None:
444
  thread_id = user_id + self.generate_thread_id()
445
  await self.conn.execute(
446
  "INSERT INTO userid_threadid (userId, threadId) VALUES (?, ?)",
 
448
  )
449
  await self.conn.commit()
450
  else:
451
+ thread_id = row[0]
452
 
453
  initial_state = {"messages": [HumanMessage(content=user_query)]}
454
  config = {
455
  "configurable": {"thread_id": thread_id},
456
+ "metadata": {"thread_id": thread_id},
457
+ "run_name": "chat_turn",
458
  }
459
  return self.ai_only_stream(initial_state, config)
frontend/index.html CHANGED
@@ -45,3 +45,4 @@
45
  <script src="script.js"></script>
46
  </body>
47
  </html>
 
 
45
  <script src="script.js"></script>
46
  </body>
47
  </html>
48
+
frontend/script.js CHANGED
@@ -1,22 +1,3 @@
1
- /* ─────────────────────────────────────────────────────────────────────────────
2
- script.js — Voice + text chat client
3
-
4
- Fixes applied
5
- ─────────────
6
- 1. DOUBLE-SEND BUG: silenceTimer is now explicitly cleared whenever
7
- isProcessing is set to true, so a timer that was already ticking
8
- can't fire a second stopRecorder() call.
9
- 2. TTS INTERRUPT / BARGE-IN: stopAllAudio() cancels the current
10
- HTMLAudioElement and sends {"type":"cancel"} to the server so the
11
- TTS pipeline also aborts server-side.
12
- 3. MARKDOWN RENDERING: AI bubble uses marked.parse() instead of
13
- textContent so Bangla markdown (bold, lists, headings) renders
14
- correctly in the chat.
15
- 4. VAD barge-in path: if the user starts speaking while TTS is playing
16
- the audio stops immediately, isProcessing resets, and the new
17
- utterance is captured normally.
18
- ───────────────────────────────────────────────────────────────────────────── */
19
-
20
  const chatBox = document.getElementById('chat-box');
21
  const sendBtn = document.getElementById('send-btn');
22
  const textInput = document.getElementById('text-input');
@@ -24,12 +5,10 @@ const micBtn = document.getElementById('mic-btn');
24
 
25
  const userId = 'walid';
26
 
27
- // ── WebSockets ────────────────────────────────────────────────────────────────
28
  const chatSocket = new WebSocket('ws://127.0.0.1:8679/ws/chat');
29
  const voiceSocket = new WebSocket('ws://127.0.0.1:8679/ws/voice');
30
  voiceSocket.binaryType = 'arraybuffer';
31
 
32
- // ── State ─────────────────────────────────────────────────────────────────────
33
  let micStream = null;
34
  let audioContext = null;
35
  let analyser = null;
@@ -39,18 +18,98 @@ let isListening = false;
39
  let isSpeaking = false;
40
  let silenceTimer = null;
41
  let vadInterval = null;
42
- let isProcessing = false; // true while server is processing / TTS playing
43
 
44
  let currentAIMessage = null;
45
- let currentAudio = null; // the HTMLAudioElement currently playing
46
- let playbackChain = Promise.resolve();
47
 
48
- // ── VAD config ────────────────────────────────────────────────────────────────
49
- const SILENCE_THRESHOLD_DB = -45; // dBFS
50
- const SILENCE_TIMEOUT_MS = 3000; // ms of silence before sending utterance
51
  const VAD_POLL_MS = 100;
52
 
53
- // ── Text chat ─────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  sendBtn.onclick = sendTextMessage;
55
  textInput.addEventListener('keydown', (e) => {
56
  if (e.key === 'Enter') sendTextMessage();
@@ -77,7 +136,6 @@ chatSocket.onmessage = (e) => {
77
  chatSocket.onerror = (e) => console.error('Chat WS error:', e);
78
  chatSocket.onclose = () => console.log('Chat WS closed');
79
 
80
- // ── Voice WebSocket events ────────────────────────────────────────────────────
81
  voiceSocket.onopen = () => console.log('[WS] Voice connected');
82
  voiceSocket.onclose = () => {
83
  console.log('[WS] Voice closed');
@@ -86,7 +144,6 @@ voiceSocket.onclose = () => {
86
  voiceSocket.onerror = (e) => console.error('[WS] Voice error:', e);
87
 
88
  voiceSocket.onmessage = (event) => {
89
- // Binary → audio playback
90
  if (event.data instanceof ArrayBuffer) {
91
  enqueueAudio(event.data);
92
  return;
@@ -106,25 +163,22 @@ voiceSocket.onmessage = (event) => {
106
  break;
107
 
108
  case 'llm_token':
109
- // FIX: stream tokens into a div; final markdown render happens on 'end'
110
  if (!currentAIMessage) {
111
  currentAIMessage = appendMessage('', 'ai');
112
  currentAIMessage._raw = '';
113
  }
114
  currentAIMessage._raw += msg.token;
115
- // Live preview: render markdown progressively
116
  currentAIMessage.innerHTML = marked.parse(currentAIMessage._raw);
117
  chatBox.scrollTop = chatBox.scrollHeight;
118
  break;
119
 
120
  case 'end':
121
- // Ensure final markdown render
122
  if (currentAIMessage && currentAIMessage._raw) {
123
  currentAIMessage.innerHTML = marked.parse(currentAIMessage._raw);
124
  }
125
  currentAIMessage = null;
126
- isProcessing = false;
127
- if (isListening) setMicStatus('listening');
128
  break;
129
 
130
  case 'error':
@@ -137,70 +191,18 @@ voiceSocket.onmessage = (event) => {
137
  break;
138
 
139
  default:
140
- console.log('[WS] Unknown msg:', msg.type);
141
  }
142
  };
143
 
144
- // ── Audio playback ─────────────────────────────────────────────────────────────
145
- function enqueueAudio(buffer) {
146
- playbackChain = playbackChain.then(() => playBuffer(buffer));
147
- }
148
-
149
- function playBuffer(buffer) {
150
- return new Promise((resolve) => {
151
- if (isProcessing === false) {
152
- resolve();
153
- return;
154
- } // cancelled mid-chain
155
-
156
- const blob = new Blob([buffer], { type: 'audio/mpeg' });
157
- const url = URL.createObjectURL(blob);
158
- const audio = new Audio(url);
159
- currentAudio = audio;
160
-
161
- const done = () => {
162
- URL.revokeObjectURL(url);
163
- currentAudio = null;
164
- resolve();
165
- };
166
- audio.onended = done;
167
- audio.onerror = () => {
168
- console.warn('[AUDIO] playback error');
169
- done();
170
- };
171
- audio.play().catch(() => done());
172
- });
173
- }
174
-
175
- /**
176
- * Stop all queued and current audio immediately.
177
- * Also sends a cancel signal to the server so TTS generation stops.
178
- */
179
- function stopAllAudio() {
180
- // Replace the chain with an already-resolved promise so queued buffers
181
- // that haven't started yet are silently dropped.
182
- playbackChain = Promise.resolve();
183
-
184
- if (currentAudio) {
185
- currentAudio.pause();
186
- currentAudio.src = '';
187
- currentAudio = null;
188
- }
189
-
190
- // Tell server to abort TTS pipeline
191
- if (voiceSocket.readyState === WebSocket.OPEN) {
192
- voiceSocket.send(JSON.stringify({ type: 'cancel' }));
193
- }
194
- }
195
-
196
- // ── Mic button ────────────────────────────────────────────────────────────────
197
  micBtn.onclick = async () => {
198
  if (!isListening) await startListening();
199
  else stopListening();
200
  };
201
 
202
- // ── Start continuous listening with VAD ───────────────────────────────────────
203
  async function startListening() {
 
 
204
  try {
205
  micStream = await navigator.mediaDevices.getUserMedia({
206
  audio: {
@@ -228,14 +230,12 @@ async function startListening() {
228
  vadInterval = setInterval(vadTick, VAD_POLL_MS);
229
  }
230
 
231
- // ── Stop everything ───────────────────────────────────────────────────────────
232
  function stopListening() {
233
  clearInterval(vadInterval);
234
  clearTimeout(silenceTimer);
235
  vadInterval = silenceTimer = null;
236
 
237
- if (isSpeaking) stopRecorder(true); // discard in-progress utterance
238
-
239
  stopAllAudio();
240
 
241
  micStream?.getTracks().forEach((t) => t.stop());
@@ -246,7 +246,6 @@ function stopListening() {
246
  setMicStatus('off');
247
  }
248
 
249
- // ── VAD polling ───────────────────────────────────────────────────────────────
250
  function vadTick() {
251
  if (!analyser) return;
252
 
@@ -258,19 +257,18 @@ function vadTick() {
258
  const speaking = db > SILENCE_THRESHOLD_DB;
259
 
260
  if (speaking) {
261
- // FIX: barge-in — user started talking while TTS is playing
262
  if (isProcessing) {
263
- console.log('[VAD] Barge-in detected — stopping TTS.');
264
  stopAllAudio();
265
  isProcessing = false;
266
  }
267
 
268
- // FIX: clear any pending silence timer so it can't double-fire
269
  clearTimeout(silenceTimer);
270
  silenceTimer = null;
271
 
272
  if (!isSpeaking) {
273
  isSpeaking = true;
 
274
  startRecorder();
275
  setMicStatus('recording');
276
  }
@@ -280,9 +278,9 @@ function vadTick() {
280
  silenceTimer = null;
281
  isSpeaking = false;
282
 
283
- // FIX: set isProcessing *before* stopping the recorder so that
284
- // if vadTick fires again during onstop it sees the flag and skips.
285
  isProcessing = true;
 
 
286
  stopRecorder(false);
287
  setMicStatus('processing');
288
  }, SILENCE_TIMEOUT_MS);
@@ -290,7 +288,6 @@ function vadTick() {
290
  }
291
  }
292
 
293
- // ── Recorder ──────────────────────────────────────────────────────────────────
294
  function startRecorder() {
295
  if (!micStream) return;
296
  audioChunks = [];
@@ -335,7 +332,6 @@ function stopRecorder(discard = false) {
335
  mediaRecorder = null;
336
  }
337
 
338
- // ── UI helpers ────────────────────────────────────────────────────────────────
339
  function setMicStatus(state) {
340
  const labels = {
341
  off: '🎤 Start Voice',
@@ -352,7 +348,6 @@ function appendMessage(text, sender) {
352
  div.className = `message ${sender}`;
353
 
354
  if (sender === 'ai' && typeof marked !== 'undefined') {
355
- // FIX: render Bangla markdown (bold, lists, headings) properly
356
  div.innerHTML = marked.parse(text);
357
  } else {
358
  div.textContent = text;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  const chatBox = document.getElementById('chat-box');
2
  const sendBtn = document.getElementById('send-btn');
3
  const textInput = document.getElementById('text-input');
 
5
 
6
  const userId = 'walid';
7
 
 
8
  const chatSocket = new WebSocket('ws://127.0.0.1:8679/ws/chat');
9
  const voiceSocket = new WebSocket('ws://127.0.0.1:8679/ws/voice');
10
  voiceSocket.binaryType = 'arraybuffer';
11
 
 
12
  let micStream = null;
13
  let audioContext = null;
14
  let analyser = null;
 
18
  let isSpeaking = false;
19
  let silenceTimer = null;
20
  let vadInterval = null;
21
+ let isProcessing = false;
22
 
23
  let currentAIMessage = null;
24
+ let _playbackCancelled = false;
 
25
 
26
+ const SILENCE_THRESHOLD_DB = -45;
27
+ const SILENCE_TIMEOUT_MS = 1200;
 
28
  const VAD_POLL_MS = 100;
29
 
30
+ let _playCtx = null;
31
+ let _schedEndTime = 0;
32
+ let _endTimer = null;
33
+
34
/**
 * Return the shared playback AudioContext, creating a new one when none
 * exists or the previous one was closed, and resuming it if suspended.
 */
function _getPlayCtx() {
  const needsNewContext = !_playCtx || _playCtx.state === 'closed';
  if (needsNewContext) {
    const ContextCtor = window.AudioContext || window.webkitAudioContext;
    _playCtx = new ContextCtor();
    // A fresh context starts its clock over, so the schedule cursor resets.
    _schedEndTime = 0;
  }

  if (_playCtx.state === 'suspended') {
    _playCtx.resume();
  }
  return _playCtx;
}
43
+
44
/**
 * Decode one audio buffer received from the server and schedule it for
 * gapless playback right after the previously scheduled buffer.
 *
 * Decoding starts immediately, but scheduling is serialised through a promise
 * chain so buffers are always scheduled in arrival order, even when a later
 * (shorter) buffer finishes decoding first.
 *
 * @param {ArrayBuffer} buffer Encoded audio data as received on the socket.
 * @returns {Promise<void>} Resolves once this buffer is scheduled or dropped.
 */
async function enqueueAudio(buffer) {
  if (_playbackCancelled) return;

  const ctx = _getPlayCtx();

  // decodeAudioData detaches its input, so decode a copy. A failed decode
  // resolves to null so it cannot break the ordering chain.
  const decoding = ctx.decodeAudioData(buffer.slice(0)).catch((err) => {
    console.warn('[AUDIO] decode error:', err);
    return null;
  });

  // The tail of the chain is stored on the function object so this block
  // needs no additional module-level state.
  const previous = enqueueAudio._tail || Promise.resolve();
  const scheduled = previous
    .then(async () => {
      const decoded = await decoding;
      if (!decoded || _playbackCancelled) return;
      // Playback may have been stopped (context closed or replaced) while
      // this buffer was decoding; never schedule onto a stale context.
      if (ctx !== _playCtx || ctx.state === 'closed') return;

      const src = ctx.createBufferSource();
      src.buffer = decoded;
      src.connect(ctx.destination);

      // Start slightly in the future, or right after the previous buffer.
      const now = ctx.currentTime;
      const startAt = Math.max(now + 0.02, _schedEndTime);
      src.start(startAt);
      _schedEndTime = startAt + decoded.duration;
    })
    .catch((err) => console.warn('[AUDIO] schedule error:', err));

  enqueueAudio._tail = scheduled;
  return scheduled;
}
67
+
68
/**
 * Called once the server sends `{type:"end"}`.
 * Arms a timer that fires the "processing done" callback when the last
 * scheduled chunk is expected to have finished playing (plus a 120 ms margin).
 * If there is no usable playback context, the callback runs immediately.
 */
function _schedulePlaybackEnd() {
  clearTimeout(_endTimer);

  const ctx = _playCtx;
  const contextUsable = ctx && ctx.state !== 'closed';
  if (!contextUsable) {
    _onPlaybackFinished();
    return;
  }

  // Time left until the scheduled audio ends, in milliseconds (never negative).
  const msUntilEnd = Math.max(0, (_schedEndTime - ctx.currentTime) * 1000);
  _endTimer = setTimeout(() => {
    if (_playbackCancelled) return;
    _onPlaybackFinished();
  }, msUntilEnd + 120);
}
87
+
88
/**
 * Mark the current turn as finished and, if the mic session is still
 * active, switch the mic status back to "listening".
 */
function _onPlaybackFinished() {
  isProcessing = false;
  if (!isListening) return;
  setMicStatus('listening');
}
92
+
93
/**
 * Stop all queued and currently-playing audio immediately.
 * Closes the playback AudioContext (which also silences sources scheduled
 * for the future), resets the schedule cursor, and sends a cancel message
 * to the server over the voice socket when it is open.
 */
function stopAllAudio() {
  _playbackCancelled = true;

  clearTimeout(_endTimer);
  _endTimer = null;

  // Detach the context first, then close it; close() errors are ignored.
  const ctx = _playCtx;
  _playCtx = null;
  _schedEndTime = 0;
  if (ctx && ctx.state !== 'closed') {
    ctx.close().catch(() => {});
  }

  if (voiceSocket.readyState !== WebSocket.OPEN) return;
  voiceSocket.send(JSON.stringify({ type: 'cancel' }));
}
112
+
113
  sendBtn.onclick = sendTextMessage;
114
  textInput.addEventListener('keydown', (e) => {
115
  if (e.key === 'Enter') sendTextMessage();
 
136
  chatSocket.onerror = (e) => console.error('Chat WS error:', e);
137
  chatSocket.onclose = () => console.log('Chat WS closed');
138
 
 
139
  voiceSocket.onopen = () => console.log('[WS] Voice connected');
140
  voiceSocket.onclose = () => {
141
  console.log('[WS] Voice closed');
 
144
  voiceSocket.onerror = (e) => console.error('[WS] Voice error:', e);
145
 
146
  voiceSocket.onmessage = (event) => {
 
147
  if (event.data instanceof ArrayBuffer) {
148
  enqueueAudio(event.data);
149
  return;
 
163
  break;
164
 
165
  case 'llm_token':
 
166
  if (!currentAIMessage) {
167
  currentAIMessage = appendMessage('', 'ai');
168
  currentAIMessage._raw = '';
169
  }
170
  currentAIMessage._raw += msg.token;
 
171
  currentAIMessage.innerHTML = marked.parse(currentAIMessage._raw);
172
  chatBox.scrollTop = chatBox.scrollHeight;
173
  break;
174
 
175
  case 'end':
 
176
  if (currentAIMessage && currentAIMessage._raw) {
177
  currentAIMessage.innerHTML = marked.parse(currentAIMessage._raw);
178
  }
179
  currentAIMessage = null;
180
+
181
+ _schedulePlaybackEnd();
182
  break;
183
 
184
  case 'error':
 
191
  break;
192
 
193
  default:
194
+ console.log('[WS] Unknown msg type:', msg.type);
195
  }
196
  };
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  micBtn.onclick = async () => {
199
  if (!isListening) await startListening();
200
  else stopListening();
201
  };
202
 
 
203
  async function startListening() {
204
+ _getPlayCtx();
205
+
206
  try {
207
  micStream = await navigator.mediaDevices.getUserMedia({
208
  audio: {
 
230
  vadInterval = setInterval(vadTick, VAD_POLL_MS);
231
  }
232
 
 
233
  function stopListening() {
234
  clearInterval(vadInterval);
235
  clearTimeout(silenceTimer);
236
  vadInterval = silenceTimer = null;
237
 
238
+ if (isSpeaking) stopRecorder(true);
 
239
  stopAllAudio();
240
 
241
  micStream?.getTracks().forEach((t) => t.stop());
 
246
  setMicStatus('off');
247
  }
248
 
 
249
  function vadTick() {
250
  if (!analyser) return;
251
 
 
257
  const speaking = db > SILENCE_THRESHOLD_DB;
258
 
259
  if (speaking) {
 
260
  if (isProcessing) {
261
+ console.log('[VAD] Barge-in — stopping TTS.');
262
  stopAllAudio();
263
  isProcessing = false;
264
  }
265
 
 
266
  clearTimeout(silenceTimer);
267
  silenceTimer = null;
268
 
269
  if (!isSpeaking) {
270
  isSpeaking = true;
271
+ _playbackCancelled = false;
272
  startRecorder();
273
  setMicStatus('recording');
274
  }
 
278
  silenceTimer = null;
279
  isSpeaking = false;
280
 
 
 
281
  isProcessing = true;
282
+ _playbackCancelled = false;
283
+
284
  stopRecorder(false);
285
  setMicStatus('processing');
286
  }, SILENCE_TIMEOUT_MS);
 
288
  }
289
  }
290
 
 
291
  function startRecorder() {
292
  if (!micStream) return;
293
  audioChunks = [];
 
332
  mediaRecorder = null;
333
  }
334
 
 
335
  function setMicStatus(state) {
336
  const labels = {
337
  off: '🎤 Start Voice',
 
348
  div.className = `message ${sender}`;
349
 
350
  if (sender === 'ai' && typeof marked !== 'undefined') {
 
351
  div.innerHTML = marked.parse(text);
352
  } else {
353
  div.textContent = text;
services/streaming.py CHANGED
@@ -1,172 +1,237 @@
1
- """
2
- services/streaming.py — Parallel + ordered TTS streamer
3
-
4
- Fixes applied
5
- ─────────────
6
- 1. BUFFER RACE — self.buffer is now only mutated while holding
7
- self._flush_lock, so add_token() and _schedule_flush() can never
8
- interleave partial writes.
9
- 2. CANCELLATION — ParallelTTSStreamer.cancel() drops all pending tasks
10
- and poisons the queue with a sentinel so stream_audio() exits
11
- immediately. app.py calls cancel() when the user starts speaking
12
- mid-playback, giving true barge-in / interrupt behaviour.
13
- 3. Markdown stripping (_clean_for_tts) is unchanged.
14
- 4. Audio ordering guarantee is unchanged (task-chain pattern).
15
- """
16
-
17
  from __future__ import annotations
18
 
19
  import asyncio
20
  import re
 
 
21
 
22
  import edge_tts
23
 
24
- VOICE = "bn-BD-NabanitaNeural"
25
- FLUSH_LEN = 80 # chars before forced flush
26
- MIN_CHARS = 5 # skip tiny fragments
27
- FLUSH_TRIGGERS = frozenset(".!?।,;:\n—–")
 
 
 
 
 
 
 
28
 
29
 
30
- # ── Markdown → plain text ──────────────────────────────────────────────────────
31
  def _clean_for_tts(text: str) -> str:
32
- text = re.sub(r"\*{1,3}", "", text)
33
- text = re.sub(r"#+\s*", "", text)
34
- text = re.sub(r"^\s*[-•]\s*", "", text, flags=re.MULTILINE)
35
  text = re.sub(r"^\s*[\d০-৯]+[.)]\s*", "", text, flags=re.MULTILINE)
36
- text = re.sub(r"`+", "", text)
37
- text = re.sub(r"\n{2,}", "\n", text)
38
  return text.strip()
39
 
40
 
41
- # ── Streamer ───────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  class ParallelTTSStreamer:
43
  """
44
- Collects LLM tokens → prosodic chunks → parallel edge-tts calls →
45
- serialised audio queue.
46
 
47
  Usage
48
  ─────
49
  streamer = ParallelTTSStreamer()
50
 
51
- # producer
52
  await streamer.add_token(token)
53
- await streamer.flush() # call once when LLM finishes
54
 
55
- # consumer (run concurrently with producer)
56
- async for chunk in streamer.stream_audio():
57
- await ws.send_bytes(chunk)
58
 
59
- # interrupt (call from any coroutine)
60
  await streamer.cancel()
61
  """
62
 
63
  def __init__(self, voice: str = VOICE) -> None:
64
  self.voice = voice
65
  self.buffer = ""
66
- self.queue: asyncio.Queue[bytes | None] = asyncio.Queue()
67
- self._prev_task: asyncio.Task | None = None
68
- self._flush_lock = asyncio.Lock()
69
  self._cancelled = False
70
- self._tasks: list[asyncio.Task] = [] # track all live tasks
 
 
 
 
 
 
71
 
72
- # ── Token intake ───────────────────────────────────────────────────────────
73
  async def add_token(self, token: str) -> None:
74
  if not token or self._cancelled:
75
  return
76
 
77
- # FIX: hold the lock for the buffer write too, not just the flush
78
- async with self._flush_lock:
79
- self.buffer += token
80
- should_flush = (
81
- any(ch in FLUSH_TRIGGERS for ch in token)
82
- or len(self.buffer) >= FLUSH_LEN
83
- )
84
 
85
- if should_flush:
86
- await self._schedule_flush()
 
87
 
88
- # ── Flush scheduling ───────────────────────────────────────────────────────
89
- async def _schedule_flush(self) -> None:
90
  if self._cancelled:
 
91
  return
92
 
93
- async with self._flush_lock:
94
- raw = self.buffer.strip()
95
- self.buffer = ""
96
 
97
  text = _clean_for_tts(raw)
98
  if len(text) < MIN_CHARS:
99
  return
100
 
101
- prev = self._prev_task
102
- task = asyncio.create_task(self._tts_ordered(text, prev))
103
- self._prev_task = task
 
 
 
 
 
104
  self._tasks.append(task)
105
- task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
106
-
107
- # ── Ordered TTS task ───────────────────────────────────────────────────────
108
- async def _tts_ordered(self, text: str, wait_for: asyncio.Task | None) -> None:
109
- # Step 1 — synthesise (may run in parallel with other chunks)
110
- audio_chunks: list[bytes] = []
111
- if not self._cancelled:
112
- try:
113
- communicate = edge_tts.Communicate(text, self.voice)
114
- async for chunk in communicate.stream():
115
- if self._cancelled:
116
- break
117
- if chunk["type"] == "audio":
118
- audio_chunks.append(chunk["data"])
119
- except Exception as exc:
120
- print(f"[TTS] edge-tts error for '{text[:40]}': {exc}")
121
-
122
- # Step 2 — wait for predecessor to finish queuing (preserves order)
123
- if wait_for and not wait_for.done():
124
- try:
125
- await wait_for
126
- except Exception:
127
- pass
128
-
129
- # Step 3 — write to queue (skipped if cancelled)
130
- if not self._cancelled:
131
- for data in audio_chunks:
132
- await self.queue.put(data)
133
-
134
- # ── Flush remaining buffer ─────────────────────────────────────────────────
135
  async def flush(self) -> None:
136
- """Call once after the LLM stream ends."""
137
- await self._schedule_flush()
138
- if self._prev_task:
139
- try:
140
- await self._prev_task
141
- except Exception:
142
- pass
143
- await self.queue.put(None) # end-of-stream sentinel
144
-
145
- # ── Interrupt / barge-in ───────────────────────────────────────────────────
 
146
  async def cancel(self) -> None:
147
  """
148
- Immediately abort all in-flight TTS tasks and unblock stream_audio().
149
- Safe to call from any coroutine while stream_audio() is running.
 
150
  """
151
  self._cancelled = True
152
 
153
- # Cancel all pending asyncio tasks
154
  for task in list(self._tasks):
155
  task.cancel()
 
 
 
 
 
 
 
 
 
 
156
 
157
- # Drain and poison the queue so stream_audio() exits
158
- while not self.queue.empty():
159
- try:
160
- self.queue.get_nowait()
161
- except asyncio.QueueEmpty:
162
- break
163
- await self.queue.put(None) # sentinel → stream_audio exits
164
 
165
- # ── Audio consumer ─────────────────────────────────────────────────────────
166
  async def stream_audio(self):
167
- """Async generator — yields ordered audio bytes until cancelled/done."""
 
 
 
 
 
 
168
  while True:
169
- chunk = await self.queue.get()
170
- if chunk is None:
171
- break
172
- yield chunk
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  import asyncio
4
  import re
5
+ from dataclasses import dataclass, field
6
+ from typing import Optional
7
 
8
  import edge_tts
9
 
10
+ VOICE = "bn-BD-NabanitaNeural"
11
+
12
+
13
+ FIRST_FLUSH_BOUNDARY_MIN = 25
14
+ FIRST_FLUSH_HARD = 70
15
+ SUBSEQUENT_FLUSH_BOUNDARY_MIN = 40
16
+ SUBSEQUENT_FLUSH_HARD = 110
17
+ MIN_CHARS = 4
18
+
19
+ SENTENCE_BOUNDARIES = frozenset(".!?।॥\n")
20
+ CLAUSE_BOUNDARIES = frozenset(",;:—–")
21
 
22
 
 
23
def _clean_for_tts(text: str) -> str:
    """Strip common Markdown markup so the text can be spoken verbatim.

    Removes link/image syntax (keeping only the visible label), emphasis
    asterisks, heading hashes, bullet and numbered-list prefixes (ASCII or
    Bangla digits), backticks, and collapses runs of blank lines.

    Args:
        text: Raw LLM output, possibly containing Markdown.

    Returns:
        The cleaned text with leading/trailing whitespace removed.
    """
    # Links and images first: "[label](url)" / "![alt](url)" -> "label".
    # Without this the raw URL would be passed to the speech synthesiser.
    text = re.sub(r"!?\[([^\]]*)\]\([^)]*\)", r"\1", text)
    text = re.sub(r"\*{1,3}", "", text)                         # bold / italic markers
    text = re.sub(r"#+\s*", "", text)                           # heading markers
    text = re.sub(r"^\s*[-•]\s*", "", text, flags=re.MULTILINE)  # bullet prefixes
    text = re.sub(r"^\s*[\d০-৯]+[.)]\s*", "", text, flags=re.MULTILINE)  # "1." / "১)"
    text = re.sub(r"`+", "", text)                              # inline / fenced code ticks
    text = re.sub(r"\n{2,}", "\n", text)                        # collapse blank lines
    return text.strip()
31
 
32
 
33
def _should_flush(buffer: str, first_chunk: bool) -> bool:
    """Return True if the buffer is ready to be sent to TTS.

    Flushing strategy (per chunk):
      1. At or beyond the hard limit (even mid-sentence) -> flush.
      2. Ends on a sentence boundary with at least the minimum length -> flush.
      3. Ends on a clause boundary at >= 80% of the hard limit -> flush.

    The first chunk uses the smaller ``FIRST_*`` thresholds; later chunks use
    the ``SUBSEQUENT_*`` thresholds.

    Trailing spaces and tabs are ignored when looking for the boundary
    character, so a token such as ``"। "`` (punctuation followed by a space)
    is still recognised as a boundary.

    Args:
        buffer: Text accumulated since the last flush.
        first_chunk: True while no chunk has been flushed yet.
    """
    n = len(buffer)
    if n == 0:
        return False

    boundary_min = FIRST_FLUSH_BOUNDARY_MIN if first_chunk else SUBSEQUENT_FLUSH_BOUNDARY_MIN
    hard_limit = FIRST_FLUSH_HARD if first_chunk else SUBSEQUENT_FLUSH_HARD

    if n >= hard_limit:
        return True

    # Only spaces/tabs are stripped: "\n" is itself a sentence boundary.
    tail = buffer.rstrip(" \t")
    if not tail:
        return False
    last_char = tail[-1]

    if last_char in SENTENCE_BOUNDARIES and n >= boundary_min:
        return True
    if last_char in CLAUSE_BOUNDARIES and n >= hard_limit * 0.8:
        return True

    return False
58
+
59
+
60
@dataclass
class _AudioSlot:
    """Container for the synthesised audio of one TTS chunk.

    Slots are numbered in creation order via ``index``.
    """

    # Position of this chunk in the output sequence.
    index: int
    # Set once synthesis for this slot has finished (successfully or not).
    ready: asyncio.Event = field(default_factory=asyncio.Event)
    # Audio byte fragments collected from the synthesiser.
    chunks: list[bytes] = field(default_factory=list)
    # True when synthesis failed or was cancelled.
    error: bool = False
67
+
68
+
69
  class ParallelTTSStreamer:
70
  """
71
+ Collects LLM tokens → prosodic sentence chunks → parallel edge-tts
72
+ synthesis → slot-ordered audio delivery.
73
 
74
  Usage
75
  ─────
76
  streamer = ParallelTTSStreamer()
77
 
78
+
79
  await streamer.add_token(token)
80
+ await streamer.flush()
81
 
82
+
83
+ async for audio_bytes in streamer.stream_audio():
84
+ await ws.send_bytes(audio_bytes)
85
 
86
+
87
  await streamer.cancel()
88
  """
89
 
90
def __init__(self, voice: str = VOICE) -> None:
    """Initialise an empty streamer for the given edge-tts voice."""
    self.voice = voice

    # Token accumulation state.
    self.buffer = ""
    self._first_chunk = True   # the first chunk uses smaller flush thresholds

    # Slot bookkeeping: slots are created in order and guarded by a lock.
    self._slot_index = 0
    self._slots: list[_AudioSlot] = []
    self._slots_lock = asyncio.Lock()

    # Lifecycle state.
    self._cancelled = False
    self._tasks: list[asyncio.Task] = []
    self._done_event = asyncio.Event()
100
+
101
 
 
102
async def add_token(self, token: str) -> None:
    """Append an LLM token to the buffer and flush it when a chunk is ready.

    Empty tokens and tokens arriving after cancellation are ignored.
    """
    if self._cancelled:
        return
    if not token:
        return

    self.buffer += token

    ready = _should_flush(self.buffer, self._first_chunk)
    if not ready:
        return

    # From now on the larger "subsequent chunk" thresholds apply.
    self._first_chunk = False
    await self._schedule_chunk()
111
 
112
+
113
+ async def _schedule_chunk(self) -> None:
114
  if self._cancelled:
115
+ self.buffer = ""
116
  return
117
 
118
+ raw = self.buffer.strip()
119
+ self.buffer = ""
 
120
 
121
  text = _clean_for_tts(raw)
122
  if len(text) < MIN_CHARS:
123
  return
124
 
125
+
126
+
127
+ async with self._slots_lock:
128
+ slot = _AudioSlot(index=self._slot_index)
129
+ self._slot_index += 1
130
+ self._slots.append(slot)
131
+
132
+ task = asyncio.create_task(self._synthesise(text, slot))
133
  self._tasks.append(task)
134
+ task.add_done_callback(
135
+ lambda t: self._tasks.remove(t) if t in self._tasks else None
136
+ )
137
+
138
+
139
+ async def _synthesise(self, text: str, slot: _AudioSlot) -> None:
140
+ if self._cancelled:
141
+ slot.error = True
142
+ slot.ready.set()
143
+ return
144
+
145
+ try:
146
+ communicate = edge_tts.Communicate(text, self.voice)
147
+ async for chunk in communicate.stream():
148
+ if self._cancelled:
149
+ slot.error = True
150
+ slot.ready.set()
151
+ return
152
+ if chunk["type"] == "audio":
153
+ slot.chunks.append(chunk["data"])
154
+ except asyncio.CancelledError:
155
+ slot.error = True
156
+ except Exception as exc:
157
+ print(f"[TTS] edge-tts error for '{text[:50]}': {exc}")
158
+ slot.error = True
159
+ finally:
160
+ slot.ready.set()
161
+
162
+
 
163
  async def flush(self) -> None:
164
+
165
+ if self.buffer.strip():
166
+ await self._schedule_chunk()
167
+
168
+
169
+ if self._tasks:
170
+ await asyncio.gather(*self._tasks, return_exceptions=True)
171
+
172
+ self._done_event.set()
173
+
174
+
175
  async def cancel(self) -> None:
176
  """
177
+ Immediately abort all in-flight synthesis tasks.
178
+ Marks all pending slots as errored so stream_audio() exits promptly.
179
+ Idempotent.
180
  """
181
  self._cancelled = True
182
 
 
183
  for task in list(self._tasks):
184
  task.cancel()
185
+ self._tasks.clear()
186
+
187
+
188
+ async with self._slots_lock:
189
+ for slot in self._slots:
190
+ if not slot.ready.is_set():
191
+ slot.error = True
192
+ slot.ready.set()
193
+
194
+ self._done_event.set()
195
 
 
 
 
 
 
 
 
196
 
 
197
  async def stream_audio(self):
198
+ """
199
+ Yields ordered audio bytes. Slots are consumed in creation order;
200
+ each slot is awaited individually so synthesis of slot N+1 can
201
+ proceed in parallel while the consumer is yielding slot N's bytes.
202
+ """
203
+ delivered = 0
204
+
205
  while True:
206
+
207
+ async with self._slots_lock:
208
+ if delivered < len(self._slots):
209
+ slot = self._slots[delivered]
210
+ else:
211
+ slot = None
212
+
213
+ if slot is None:
214
+
215
+ if self._done_event.is_set():
216
+ break
217
+ await asyncio.sleep(0.005)
218
+ continue
219
+
220
+
221
+ await slot.ready.wait()
222
+
223
+ if not self._cancelled and not slot.error:
224
+ for audio_bytes in slot.chunks:
225
+ yield audio_bytes
226
+
227
+ delivered += 1
228
+
229
+
230
+ def reset(self) -> None:
231
+ self._cancelled = False
232
+ self._first_chunk = True
233
+ self.buffer = ""
234
+ self._slot_index = 0
235
+ self._slots.clear()
236
+ self._tasks.clear()
237
+ self._done_event.clear()