File size: 16,567 Bytes
f84481c 5dabf9d f84481c 5dabf9d f84481c 5dabf9d f84481c 5dabf9d f84481c 5dabf9d f84481c 5dabf9d f84481c 5dabf9d f84481c 5dabf9d f84481c 5dabf9d f84481c 5dabf9d f84481c 5dabf9d f84481c 5dabf9d f84481c 77a79ae f84481c 77a79ae f84481c 77a79ae f84481c 5dabf9d f84481c 5dabf9d f84481c 5dabf9d f84481c 5dabf9d f84481c 5dabf9d f84481c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 | """
services/webrtc_pipeline.py — WebRTC Audio Pipeline + Full Parallelization
FIX-BUG3 (AudioFrameReceiver never driven):
In the original code, AudioFrameReceiver was instantiated but its recv()
method was never called. aiortc only delivers frames when a consumer calls
recv() in a loop. Without this, the frame queue was always empty → no audio
reached the VAD → no utterances → zero voice responses via WebRTC.
Fix: spawn a coroutine (_recv_loop) that calls receiver.recv() continuously.
All other logic preserved.
"""
from __future__ import annotations
import asyncio
import json
import uuid
from typing import Optional
try:
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
from aiortc.contrib.media import MediaBlackhole
import av
AIORTC_AVAILABLE = True
except ImportError:
AIORTC_AVAILABLE = False
print("[WebRTC] aiortc not installed β WebRTC pipeline unavailable. "
"Install: pip install aiortc")
try:
import webrtcvad
VAD_AVAILABLE = True
except ImportError:
VAD_AVAILABLE = False
print("[WebRTC] webrtcvad not installed β VAD unavailable.")
from services.stt import STTProcessor
from services.streaming import ParallelTTSStreamer
# ──────────────────────────────────────────────────────────────────────────────
# VAD SEGMENTER (PCM frames → speech utterances)
# ──────────────────────────────────────────────────────────────────────────────
class _VADSegmenter:
    """
    Segment a stream of raw 16-bit mono PCM into speech utterances.

    Input chunks may be any length. They are re-chunked into exact
    ``frame_ms`` frames before reaching the VAD, so senders whose packets are
    longer or shorter than one VAD frame (e.g. a 60 ms Opus ptime) neither
    lose audio nor get zero-padding injected. An utterance is emitted once
    ``silence_limit`` consecutive non-speech frames follow speech.
    """

    def __init__(
        self,
        sample_rate: int = 16_000,
        frame_ms: int = 20,  # 20ms frames — aiortc default
        aggressiveness: int = 2,
        silence_limit: int = 12,  # ~240ms silence → end of utterance
        vad: Optional[object] = None,
    ) -> None:
        """
        Args:
            sample_rate: PCM sample rate in Hz (also passed to the VAD).
            frame_ms: duration of one VAD frame in milliseconds.
            aggressiveness: webrtcvad mode; only used when ``vad`` is None.
            silence_limit: consecutive non-speech frames that end an utterance.
            vad: optional object exposing ``is_speech(frame, sample_rate)``.
                When omitted, a ``webrtcvad.Vad`` is created if webrtcvad is
                installed; otherwise the fixed 3-second fallback is used.

        Raises:
            ValueError: if ``frame_ms`` yields an empty frame at ``sample_rate``.
        """
        self.sample_rate = sample_rate
        self.frame_bytes = int(sample_rate * frame_ms / 1000) * 2  # 16-bit samples
        if self.frame_bytes <= 0:
            # A zero-size frame would make the re-chunking loop spin forever.
            raise ValueError("frame_ms is too small for the given sample_rate")
        self.silence_limit = silence_limit
        if vad is not None:
            self._vad = vad
        else:
            self._vad = webrtcvad.Vad(aggressiveness) if VAD_AVAILABLE else None
        self._buffer = bytearray()  # audio of the utterance in progress
        self._pending = bytearray()  # leftover tail shorter than one VAD frame
        self._ready: list[bytes] = []  # finished utterances not yet returned
        self._silence_count = 0
        self._active = False

    def process_frame(self, pcm_frame: bytes) -> Optional[bytes]:
        """
        Feed a chunk of 16-bit mono PCM (any length).

        Returns one complete utterance when speech has ended, else None. If a
        single large chunk completes several utterances, the extra ones are
        returned by subsequent calls.
        """
        if self._vad is None:
            # No VAD → buffer everything, flush after 3s
            self._buffer.extend(pcm_frame)
            if len(self._buffer) >= self.sample_rate * 3 * 2:
                data = bytes(self._buffer)
                self._buffer.clear()
                return data
            return None

        # Re-chunk into exact VAD-sized frames. Trimming/padding each chunk to
        # one frame (the previous approach) discarded audio whenever a chunk
        # was longer than frame_bytes and inserted silence when shorter.
        self._pending.extend(pcm_frame)
        size = self.frame_bytes
        while len(self._pending) >= size:
            frame = bytes(self._pending[:size])
            del self._pending[:size]
            self._feed_vad_frame(frame)

        return self._ready.pop(0) if self._ready else None

    def _feed_vad_frame(self, frame: bytes) -> None:
        """Classify one exact-size frame and advance the utterance state."""
        try:
            is_speech = self._vad.is_speech(frame, self.sample_rate)
        except Exception:
            # Treat VAD failures as silence rather than aborting the stream.
            is_speech = False
        if is_speech:
            self._buffer.extend(frame)
            self._active = True
            self._silence_count = 0
        elif self._active:
            # Keep trailing silence inside the utterance until the limit hits.
            self._buffer.extend(frame)
            self._silence_count += 1
        if self._active and self._silence_count >= self.silence_limit:
            self._ready.append(bytes(self._buffer))
            self._buffer.clear()
            self._silence_count = 0
            self._active = False
# ──────────────────────────────────────────────────────────────────────────────
# AUDIO TRACK RECEIVER
# ──────────────────────────────────────────────────────────────────────────────
if AIORTC_AVAILABLE:
    class AudioFrameReceiver(MediaStreamTrack):
        """
        Wraps an incoming WebRTC audio track.

        Resamples to 16 kHz mono s16 PCM and pushes the raw PCM bytes into an
        asyncio.Queue. IMPORTANT: call start_receiving() after construction to
        begin consuming frames from the underlying track via recv().
        """

        kind = "audio"

        def __init__(self, track: MediaStreamTrack, frame_queue: asyncio.Queue) -> None:
            super().__init__()
            self._track = track
            self._frame_queue = frame_queue
            self._resampler: Optional[av.AudioResampler] = None
            self._recv_task: Optional[asyncio.Task] = None

        def start_receiving(self) -> None:
            """
            FIX-BUG3: Spawn the recv() loop so the track actually delivers frames.

            Without this, _frame_queue stays empty forever. Calling it again
            while the loop is still running is a no-op.
            """
            if self._recv_task is None or self._recv_task.done():
                self._recv_task = asyncio.ensure_future(self._recv_loop())

        async def _recv_loop(self) -> None:
            """Continuously consume frames from the remote track until it ends."""
            while True:
                try:
                    frame = await self._track.recv()
                except Exception as exc:
                    print(f"[WebRTC] AudioFrameReceiver: track ended ({exc})")
                    break
                if self._resampler is None:
                    self._resampler = av.AudioResampler(
                        format="s16",
                        layout="mono",
                        rate=16_000,
                    )
                try:
                    for pcm in self._to_pcm_chunks(self._resampler.resample(frame)):
                        try:
                            self._frame_queue.put_nowait(pcm)
                        except asyncio.QueueFull:
                            pass  # Drop frame under backpressure — prefer real-time
                except Exception as exc:
                    print(f"[WebRTC] Resample error: {exc}")

        @staticmethod
        def _to_pcm_chunks(resampled) -> list[bytes]:
            """
            Convert AudioResampler output into exact-length s16 mono PCM bytes.

            PyAV >= 9 returns a list of frames; older versions return a single
            frame or None, so both shapes are accepted.
            """
            if resampled is None:
                return []
            if not isinstance(resampled, list):
                resampled = [resampled]
            chunks: list[bytes] = []
            for rf in resampled:
                # The plane buffer is sized by the frame's linesize and may carry
                # alignment padding past the real samples; keep only
                # samples * 2 bytes (s16, one channel).
                pcm = bytes(rf.planes[0])[: rf.samples * 2]
                if pcm:
                    chunks.append(pcm)
            return chunks

        async def recv(self):
            """
            Required MediaStreamTrack override — delegates to the underlying track.

            NOTE: _recv_loop already consumes the underlying track, so any other
            caller of recv() would compete with it for frames.
            """
            return await self._track.recv()

        def stop_receiving(self) -> None:
            """Cancel the background recv() loop if it is still running."""
            if self._recv_task and not self._recv_task.done():
                self._recv_task.cancel()
# ──────────────────────────────────────────────────────────────────────────────
# TURN PIPELINE (STT → LLM → TTS, all parallel)
# ──────────────────────────────────────────────────────────────────────────────
class _TurnPipeline:
    """
    Runs one conversation turn: speech bytes → transcript → LLM stream → audio.

    Designed to be created fresh per turn (or cancelled on barge-in). JSON
    control messages and synthesized audio bytes are sent over the supplied
    data channel on a best-effort basis.
    """

    def __init__(self, ai_backend, data_channel, on_stt=None, on_token=None):
        self._ai = ai_backend
        self._channel = data_channel  # RTCDataChannel for audio delivery
        self._on_stt = on_stt  # optional callback(str)
        self._on_token = on_token  # optional callback(str)
        self._stt = STTProcessor()
        self._streamer = ParallelTTSStreamer()
        self._cancelled = False
        # Phase-2 sub-tasks (LLM + TTS delivery), tracked so cancel() can stop them.
        self._tasks: list[asyncio.Task] = []

    async def run(self, user_id: str, audio_bytes: bytes) -> None:
        """
        Full pipeline: audio → STT → LLM+TTS (parallel).

        Sends an ``{"type": "end"}`` control message when the turn finishes,
        including when STT fails or produces no transcript.
        """
        # ── Phase 1: STT ────────────────────────────────────────────────────
        try:
            transcript = await self._stt.transcribe(audio_bytes)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            # Without this the task dies silently and the client never gets "end".
            print(f"[Pipeline] STT error: {exc}")
            transcript = None
        if not transcript or self._cancelled:
            self._send_ctrl({"type": "end"})
            return
        if self._on_stt:
            self._on_stt(transcript)
        self._send_ctrl({"type": "stt", "text": transcript})

        # ── Phase 2: LLM + TTS in parallel ──────────────────────────────────
        self._tasks = [
            asyncio.create_task(self._run_llm(user_id, transcript)),
            asyncio.create_task(self._run_tts_delivery()),
        ]
        await asyncio.gather(*self._tasks, return_exceptions=True)
        if not self._cancelled:
            self._send_ctrl({"type": "end"})

    async def _run_llm(self, user_id: str, transcript: str) -> None:
        """Stream LLM tokens to the client, then hand the full text to TTS."""
        parts: list[str] = []
        try:
            stream = await self._ai.main(user_id, transcript)
            async for token in stream:
                if self._cancelled:
                    break
                if not token:
                    # Empty/None chunks are common in LLM streams (e.g. the first
                    # role-only delta); skip them instead of truncating the reply.
                    continue
                parts.append(token)
                if self._on_token:
                    self._on_token(token)
                self._send_ctrl({"type": "llm_token", "token": token})
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            print(f"[Pipeline] LLM error: {exc}")
        finally:
            if not self._cancelled:
                full_text = "".join(parts)
                # Feed the completed response to TTS for more reliable synthesis.
                if full_text:
                    await self._streamer.add_token(full_text)
                # Flush even when nothing was produced: presumably flush() is what
                # ends stream_audio(), and skipping it left _run_tts_delivery
                # waiting forever so "end" was never sent.
                await self._streamer.flush()

    async def _run_tts_delivery(self) -> None:
        """Forward audio chunks from the TTS streamer to the WebRTC data channel."""
        async for chunk in self._streamer.stream_audio():
            if self._cancelled:
                break
            self._send_audio(chunk)

    def _send_ctrl(self, payload: dict) -> None:
        """Send a JSON control message if the channel is open (best-effort)."""
        if self._channel and self._channel.readyState == "open":
            try:
                self._channel.send(json.dumps(payload))
            except Exception:
                pass  # channel may close between the state check and send

    def _send_audio(self, data: bytes) -> None:
        """Send a raw audio chunk if the channel is open (best-effort)."""
        if self._channel and self._channel.readyState == "open":
            try:
                self._channel.send(data)
            except Exception:
                pass  # channel may close between the state check and send

    async def cancel(self) -> None:
        """Stop the turn: cancel TTS and any running LLM/TTS sub-tasks."""
        self._cancelled = True
        await self._streamer.cancel()
        pending = [t for t in self._tasks if not t.done()]
        for t in pending:
            t.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
# ──────────────────────────────────────────────────────────────────────────────
# WEBRTC SESSION HANDLER
# ──────────────────────────────────────────────────────────────────────────────
class WebRTCSession:
    """
    Manages one WebRTC peer connection:
      • Handles ICE/SDP negotiation
      • Receives audio track → VAD → STT queue
      • Sends responses back via RTCDataChannel
      • Supports barge-in (cancel active turn on new speech)

    Call close() when finished. Track receivers and the frame-processing task
    are also released when the connection state becomes "failed" or "closed".
    """

    def __init__(self, ai_backend) -> None:
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is required for WebRTC mode")
        self._ai = ai_backend
        self.user_id = f"rtc_{uuid.uuid4().hex[:12]}"
        self._pc = RTCPeerConnection()
        self._channel = None
        self._frame_q: asyncio.Queue = asyncio.Queue(maxsize=500)
        self._vad = _VADSegmenter()
        self._active_turn: Optional[_TurnPipeline] = None
        self._active_task: Optional[asyncio.Task] = None
        # Keep references to receivers so they are not garbage-collected
        self._receivers: list[AudioFrameReceiver] = []
        # Single consumer of _frame_q, shared by all audio tracks; kept so it
        # can be cancelled instead of outliving the session.
        self._processor_task: Optional[asyncio.Task] = None
        self._setup_pc()

    def _setup_pc(self) -> None:
        """Register aiortc event handlers on the peer connection."""
        pc = self._pc

        @pc.on("track")
        def on_track(track):
            if track.kind == "audio":
                # FIX-BUG3: create receiver AND start its recv() loop
                receiver = AudioFrameReceiver(track, self._frame_q)
                receiver.start_receiving()
                self._receivers.append(receiver)  # prevent GC
                # Start the queue consumer once, however many tracks arrive.
                self._ensure_frame_processor()
                print(f"[WebRTC] Audio track received β receiver started β")

        @pc.on("datachannel")
        def on_datachannel(channel):
            self._channel = channel
            print(f"[WebRTC] DataChannel open: {channel.label}")

            @channel.on("message")
            def on_message(msg):
                try:
                    data = json.loads(msg)
                    if data.get("type") == "cancel":
                        asyncio.ensure_future(self._cancel_active())
                    elif data.get("type") == "init" and data.get("user_id"):
                        # NOTE(review): user_id is client-supplied and trusted as-is.
                        self.user_id = str(data["user_id"])[:64]
                except Exception:
                    pass  # ignore malformed control messages

        @pc.on("connectionstatechange")
        async def on_state():
            print(f"[WebRTC] Connection state: {pc.connectionState}")
            if pc.connectionState in ("failed", "closed"):
                self._stop_media()
                await self._cancel_active()

    def _ensure_frame_processor(self) -> None:
        """Start the frame-consumer task unless one is already running."""
        if self._processor_task is None or self._processor_task.done():
            self._processor_task = asyncio.ensure_future(self._frame_processor())

    def _stop_media(self) -> None:
        """Stop all track receivers and the frame-consumer task (idempotent)."""
        for receiver in self._receivers:
            receiver.stop_receiving()
        self._receivers.clear()
        if self._processor_task is not None and not self._processor_task.done():
            self._processor_task.cancel()
        self._processor_task = None

    async def _frame_processor(self) -> None:
        """Consume PCM frames from the queue → VAD → dispatch turns."""
        while True:
            frame = await self._frame_q.get()
            utterance = self._vad.process_frame(frame)
            if not utterance:
                continue
            try:
                await self._dispatch_turn(utterance)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                # A failed dispatch must not silence the rest of the session.
                print(f"[WebRTC] Turn dispatch error: {exc}")

    async def _dispatch_turn(self, audio_bytes: bytes) -> None:
        """Barge-in aware: cancel current turn, start new one."""
        await self._cancel_active()
        pipeline = _TurnPipeline(
            ai_backend=self._ai,
            data_channel=self._channel,
        )
        self._active_turn = pipeline
        self._active_task = asyncio.create_task(
            pipeline.run(self.user_id, audio_bytes)
        )

    async def _cancel_active(self) -> None:
        """Cancel the in-flight turn (if any) and wait for its task to finish."""
        if self._active_turn:
            await self._active_turn.cancel()
            self._active_turn = None
        if self._active_task and not self._active_task.done():
            self._active_task.cancel()
            try:
                await self._active_task
            except (asyncio.CancelledError, Exception):
                pass
        self._active_task = None

    async def handle_offer(self, sdp: str, sdp_type: str) -> dict:
        """Process SDP offer from browser. Returns SDP answer."""
        offer = RTCSessionDescription(sdp=sdp, type=sdp_type)
        await self._pc.setRemoteDescription(offer)
        answer = await self._pc.createAnswer()
        await self._pc.setLocalDescription(answer)
        return {
            "sdp": self._pc.localDescription.sdp,
            "type": self._pc.localDescription.type,
        }

    async def add_ice_candidate(self, candidate: dict) -> None:
        """
        Forward a browser ICE candidate to aiortc.

        Accepts the browser's ``RTCIceCandidate.toJSON()`` shape (a
        ``candidate`` SDP string plus ``sdpMid``/``sdpMLineIndex``), or, as
        before, a dict of pre-parsed fields (``ip``, ``port``, ``foundation``,
        ...). An empty ``candidate`` string (end-of-candidates) is ignored.
        """
        from aiortc import RTCIceCandidate
        from aiortc.sdp import candidate_from_sdp

        sdp_line = candidate.get("candidate")
        if sdp_line is not None:
            if not sdp_line:
                return  # end-of-candidates marker; nothing to add
            prefix = "candidate:"
            if sdp_line.startswith(prefix):
                sdp_line = sdp_line[len(prefix):]
            c = candidate_from_sdp(sdp_line)
            c.sdpMid = candidate.get("sdpMid")
            c.sdpMLineIndex = candidate.get("sdpMLineIndex")
        else:
            c = RTCIceCandidate(
                component=candidate.get("component", 1),
                foundation=candidate.get("foundation", ""),
                ip=candidate.get("ip", ""),
                port=candidate.get("port", 0),
                priority=candidate.get("priority", 0),
                protocol=candidate.get("protocol", "udp"),
                type=candidate.get("type", "host"),
                sdpMid=candidate.get("sdpMid"),
                sdpMLineIndex=candidate.get("sdpMLineIndex"),
            )
        await self._pc.addIceCandidate(c)

    async def close(self) -> None:
        """Release receivers and the frame consumer, cancel the turn, close the PC."""
        self._stop_media()
        await self._cancel_active()
        await self._pc.close()
|