"""
services/webrtc_pipeline.py — WebRTC Audio Pipeline + Full Parallelization

FIX-BUG3 (AudioFrameReceiver never driven):
  In the original code, AudioFrameReceiver was instantiated but its recv()
  method was never called. aiortc only delivers frames when a consumer calls
  recv() in a loop. Without this, the frame queue was always empty → no audio
  reached the VAD → no utterances → zero voice responses via WebRTC.
  Fix: spawn a coroutine (_recv_loop) that calls receiver.recv() continuously.

All other logic preserved.
"""
from __future__ import annotations

import asyncio
import json
import uuid
from typing import Optional

try:
    from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
    from aiortc.contrib.media import MediaBlackhole
    import av
    AIORTC_AVAILABLE = True
except ImportError:
    AIORTC_AVAILABLE = False
    print("[WebRTC] aiortc not installed — WebRTC pipeline unavailable. "
          "Install: pip install aiortc")

try:
    import webrtcvad
    VAD_AVAILABLE = True
except ImportError:
    VAD_AVAILABLE = False
    print("[WebRTC] webrtcvad not installed — VAD unavailable.")

from services.stt import STTProcessor
from services.streaming import ParallelTTSStreamer


# ══════════════════════════════════════════════════════════════════════════════
# VAD SEGMENTER  (PCM frames → speech utterances)
# ══════════════════════════════════════════════════════════════════════════════

class _VADSegmenter:
    """
    Accumulates raw 16-bit mono PCM and yields complete utterances when
    silence follows speech.

    Incoming audio is re-chunked into frames of exactly ``frame_bytes`` bytes
    before being classified, so callers do not have to deliver perfectly
    sized frames: oversized input is split and a trailing partial frame is
    kept until the next call instead of being truncated or zero-padded.
    """

    def __init__(
        self,
        sample_rate: int = 16_000,
        frame_ms: int = 20,           # 20ms frames — aiortc default
        aggressiveness: int = 2,
        silence_limit: int = 12,      # ~240ms silence → end of utterance
    ) -> None:
        """
        Args:
            sample_rate: PCM sample rate in Hz.
            frame_ms: Duration of one VAD frame in milliseconds.
            aggressiveness: Value passed to ``webrtcvad.Vad``.
            silence_limit: Number of consecutive non-speech frames that ends
                an utterance.

        Raises:
            ValueError: If ``sample_rate`` and ``frame_ms`` yield an empty
                frame (which could never be classified).
        """
        self.sample_rate = sample_rate
        self.frame_bytes = int(sample_rate * frame_ms / 1000) * 2  # 16-bit samples
        if self.frame_bytes <= 0:
            raise ValueError("sample_rate and frame_ms must yield a non-empty frame")
        self.silence_limit = silence_limit
        self._vad = webrtcvad.Vad(aggressiveness) if VAD_AVAILABLE else None
        self._buffer = bytearray()    # audio of the utterance being built
        self._pending = bytearray()   # bytes not yet cut into a full VAD frame
        self._silence_count = 0
        self._active = False

    def process_frame(self, pcm_frame: bytes) -> Optional[bytes]:
        """
        Feed PCM audio (nominally one 20ms frame).

        Returns:
            The bytes of a complete utterance when speech has ended,
            otherwise ``None``. At most one utterance is returned per call;
            any input left over after it stays queued and is consumed by the
            next call.
        """
        if self._vad is None:
            # No VAD — buffer everything, flush after 3s
            self._buffer.extend(pcm_frame)
            if len(self._buffer) >= self.sample_rate * 3 * 2:
                return self._take_utterance()
            return None

        self._pending.extend(pcm_frame)
        size = self.frame_bytes
        while len(self._pending) >= size:
            frame = bytes(self._pending[:size])
            del self._pending[:size]
            utterance = self._classify(frame)
            if utterance is not None:
                return utterance
        return None

    def _classify(self, frame: bytes) -> Optional[bytes]:
        """Run the VAD on one exact-size frame and update segmenter state."""
        try:
            is_speech = self._vad.is_speech(frame, self.sample_rate)
        except Exception:
            # A frame the VAD rejects is treated as non-speech (best-effort).
            is_speech = False

        if is_speech:
            self._buffer.extend(frame)
            self._active = True
            self._silence_count = 0
        elif self._active:
            # Keep trailing silence so the utterance is not clipped.
            self._buffer.extend(frame)
            self._silence_count += 1

        if self._active and self._silence_count >= self.silence_limit:
            return self._take_utterance()
        return None

    def _take_utterance(self) -> bytes:
        """Return the buffered utterance and reset state for the next one."""
        data = bytes(self._buffer)
        self._buffer.clear()
        self._silence_count = 0
        self._active = False
        return data


# ══════════════════════════════════════════════════════════════════════════════
# AUDIO TRACK RECEIVER
# ══════════════════════════════════════════════════════════════════════════════

if AIORTC_AVAILABLE:

    class AudioFrameReceiver(MediaStreamTrack):
        """
        Wraps an incoming WebRTC audio track.
        Resamples to 16kHz mono PCM and pushes frames into an asyncio.Queue.

        IMPORTANT: call start_receiving() after construction to begin
        consuming frames from the underlying track via recv().
        """

        kind = "audio"

        def __init__(self, track: MediaStreamTrack, frame_queue: asyncio.Queue) -> None:
            super().__init__()
            self._track = track
            self._frame_queue = frame_queue
            self._resampler: Optional[av.AudioResampler] = None
            self._recv_task: Optional[asyncio.Task] = None

        def start_receiving(self) -> None:
            """
            FIX-BUG3: Spawn the recv() loop so the track actually delivers
            frames. Without this, _frame_queue stays empty forever.

            Calling it again while the loop is still running is a no-op.
            """
            if self._recv_task is None or self._recv_task.done():
                self._recv_task = asyncio.ensure_future(self._recv_loop())

        async def _recv_loop(self) -> None:
            """Continuously consume frames from the remote track."""
            while True:
                try:
                    frame = await self._track.recv()
                except Exception as exc:
                    print(f"[WebRTC] AudioFrameReceiver: track ended ({exc})")
                    break

                if self._resampler is None:
                    # Created lazily, on the first received frame.
                    self._resampler = av.AudioResampler(
                        format="s16",
                        layout="mono",
                        rate=16_000,
                    )

                try:
                    for rf in self._resampler.resample(frame):
                        # The plane buffer may be padded past the real
                        # samples; keep only samples * 2 bytes (s16 mono).
                        pcm = bytes(rf.planes[0])[: rf.samples * 2]
                        try:
                            self._frame_queue.put_nowait(pcm)
                        except asyncio.QueueFull:
                            pass  # Drop frame under backpressure — prefer real-time
                except Exception as exc:
                    print(f"[WebRTC] Resample error: {exc}")

        async def recv(self):
            """Required override — delegates to the underlying track."""
            return await self._track.recv()

        def stop_receiving(self) -> None:
            """Cancel the recv() loop if it is still running."""
            if self._recv_task and not self._recv_task.done():
                self._recv_task.cancel()


# ══════════════════════════════════════════════════════════════════════════════
# TURN PIPELINE  (STT → LLM → TTS, all parallel)
# ══════════════════════════════════════════════════════════════════════════════

class _TurnPipeline:
    """
    Runs one conversation turn: speech bytes → transcript → LLM stream → audio.
    Designed to be created fresh per turn (or cancelled on barge-in).
    """

    def __init__(self, ai_backend, data_channel, on_stt=None, on_token=None):
        """
        Args:
            ai_backend: Object whose ``main(user_id, transcript)`` coroutine
                returns an async iterator of tokens.
            data_channel: RTCDataChannel used for control messages and audio
                delivery; may be ``None``, in which case sends are skipped.
            on_stt: Optional callback(str) invoked with the transcript.
            on_token: Optional callback(str) invoked for each LLM token.
        """
        self._ai = ai_backend
        self._channel = data_channel
        self._on_stt = on_stt
        self._on_token = on_token
        self._stt = STTProcessor()
        self._streamer = ParallelTTSStreamer()
        self._cancelled = False
        # Cancelled by cancel(); nothing in this module appends to it.
        self._tasks: list[asyncio.Task] = []

    async def run(self, user_id: str, audio_bytes: bytes) -> None:
        """Full pipeline: audio → STT → LLM+TTS (parallel)."""
        # ── Phase 1: STT ─────────────────────────────────────────────────────
        transcript = await self._stt.transcribe(audio_bytes)
        if not transcript or self._cancelled:
            self._send_ctrl({"type": "end"})
            return

        if self._on_stt:
            self._on_stt(transcript)
        self._send_ctrl({"type": "stt", "text": transcript})

        # ── Phase 2: LLM + TTS in parallel ───────────────────────────────────
        results = await asyncio.gather(
            self._run_llm(user_id, transcript),
            self._run_tts_delivery(),
            return_exceptions=True,
        )
        for result in results:
            # Surface stage failures instead of discarding them silently.
            if isinstance(result, Exception):
                print(f"[Pipeline] Turn stage failed: {result!r}")

        if not self._cancelled:
            self._send_ctrl({"type": "end"})

    async def _run_llm(self, user_id: str, transcript: str) -> None:
        """
        Stream LLM tokens to the client, then hand the full text to TTS.

        Each token is forwarded as an ``llm_token`` control message as it
        arrives; the accumulated response is passed to the TTS streamer in
        one piece once the stream ends, and the streamer is always flushed.
        """
        tokens: list[str] = []
        try:
            stream = await self._ai.main(user_id, transcript)
            async for token in stream:
                if self._cancelled or not token:
                    break
                tokens.append(token)
                if self._on_token:
                    self._on_token(token)
                self._send_ctrl({"type": "llm_token", "token": token})
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            print(f"[Pipeline] LLM error: {exc}")
        finally:
            # Feed the completed response to TTS for more reliable synthesis.
            full_text = "".join(tokens)
            try:
                if full_text and not self._cancelled:
                    await self._streamer.add_token(full_text)
            finally:
                # Flush even if add_token failed.
                await self._streamer.flush()

    async def _run_tts_delivery(self) -> None:
        """Forward audio chunks from TTS streamer to WebRTC data channel."""
        async for chunk in self._streamer.stream_audio():
            if self._cancelled:
                break
            self._send_audio(chunk)

    def _send(self, payload) -> None:
        """Best-effort send on the data channel; failures are logged only."""
        if self._channel and self._channel.readyState == "open":
            try:
                self._channel.send(payload)
            except Exception as exc:
                print(f"[Pipeline] DataChannel send failed: {exc}")

    def _send_ctrl(self, payload: dict) -> None:
        """Send a JSON-encoded control message if the channel is open."""
        self._send(json.dumps(payload))

    def _send_audio(self, data: bytes) -> None:
        """Send a binary audio chunk if the channel is open."""
        self._send(data)

    async def cancel(self) -> None:
        """Mark the turn cancelled, stop the TTS streamer and any tasks."""
        self._cancelled = True
        await self._streamer.cancel()
        for t in self._tasks:
            t.cancel()
        if self._tasks:
            await asyncio.gather(*self._tasks, return_exceptions=True)


# ══════════════════════════════════════════════════════════════════════════════
# WEBRTC SESSION HANDLER
# ══════════════════════════════════════════════════════════════════════════════

class WebRTCSession:
    """
    Manages one WebRTC peer connection:
      • Handles ICE/SDP negotiation
      • Receives audio track → VAD → STT queue
      • Sends responses back via RTCDataChannel
      • Supports barge-in (cancel active turn on new speech)
    """

    def __init__(self, ai_backend) -> None:
        """
        Raises:
            RuntimeError: If aiortc could not be imported.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is required for WebRTC mode")

        self._ai = ai_backend
        self.user_id = f"rtc_{uuid.uuid4().hex[:12]}"
        self._pc = RTCPeerConnection()
        self._channel = None
        self._frame_q: asyncio.Queue = asyncio.Queue(maxsize=500)
        self._vad = _VADSegmenter()
        self._active_turn: Optional[_TurnPipeline] = None
        self._active_task: Optional[asyncio.Task] = None
        # Keep references to receivers so they are not garbage-collected
        self._receivers: list[AudioFrameReceiver] = []
        # Single consumer of _frame_q; referenced so close() can cancel it.
        self._processor_task: Optional[asyncio.Task] = None
        # Strong references to fire-and-forget tasks until they finish.
        self._background: set[asyncio.Task] = set()
        self._setup_pc()

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule *coro* and hold a reference to it until it completes."""
        task = asyncio.ensure_future(coro)
        self._background.add(task)
        task.add_done_callback(self._background.discard)
        return task

    def _ensure_frame_processor(self) -> None:
        """Start the frame processor unless one is already running."""
        if self._processor_task is None or self._processor_task.done():
            self._processor_task = asyncio.ensure_future(self._frame_processor())

    def _stop_background(self) -> list:
        """
        Stop receivers, the frame processor and tracked background tasks.

        Returns:
            The tasks that were cancelled, so a caller may await them.
        """
        for receiver in self._receivers:
            receiver.stop_receiving()
        self._receivers.clear()

        cancelled = []
        if self._processor_task is not None and not self._processor_task.done():
            self._processor_task.cancel()
            cancelled.append(self._processor_task)
        self._processor_task = None

        for task in list(self._background):
            if not task.done():
                task.cancel()
                cancelled.append(task)
        return cancelled

    def _setup_pc(self) -> None:
        """Register the track, datachannel and connection-state handlers."""
        pc = self._pc

        @pc.on("track")
        def on_track(track):
            if track.kind != "audio":
                return
            # FIX-BUG3: create receiver AND start its recv() loop
            receiver = AudioFrameReceiver(track, self._frame_q)
            receiver.start_receiving()
            self._receivers.append(receiver)   # prevent GC
            # One processor serves the shared queue, however many tracks arrive.
            self._ensure_frame_processor()
            print("[WebRTC] Audio track received — receiver started ✓")

        @pc.on("datachannel")
        def on_datachannel(channel):
            self._channel = channel
            print(f"[WebRTC] DataChannel open: {channel.label}")

            @channel.on("message")
            def on_message(msg):
                try:
                    data = json.loads(msg)
                    if data.get("type") == "cancel":
                        self._spawn(self._cancel_active())
                    elif data.get("type") == "init" and data.get("user_id"):
                        self.user_id = str(data["user_id"])[:64]
                except Exception:
                    # Messages that are not a JSON object are ignored.
                    pass

        @pc.on("connectionstatechange")
        async def on_state():
            print(f"[WebRTC] Connection state: {pc.connectionState}")
            if pc.connectionState in ("failed", "closed"):
                await self._cancel_active()
                # Nothing more will arrive — stop the consumers as well.
                self._stop_background()

    async def _frame_processor(self) -> None:
        """Consume PCM frames from queue → VAD → dispatch turns.

        Runs until cancelled (by close() or a failed/closed connection).
        """
        while True:
            frame = await self._frame_q.get()
            try:
                utterance = self._vad.process_frame(frame)
                if utterance:
                    await self._dispatch_turn(utterance)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                # One bad frame or turn must not end audio processing.
                print(f"[WebRTC] Frame processing error: {exc}")

    async def _dispatch_turn(self, audio_bytes: bytes) -> None:
        """Barge-in aware: cancel current turn, start new one."""
        await self._cancel_active()

        pipeline = _TurnPipeline(
            ai_backend=self._ai,
            data_channel=self._channel,
        )
        self._active_turn = pipeline
        self._active_task = asyncio.create_task(
            pipeline.run(self.user_id, audio_bytes)
        )

    async def _cancel_active(self) -> None:
        """Cancel the running turn (if any) and wait for its task to end."""
        if self._active_turn:
            await self._active_turn.cancel()
            self._active_turn = None
        if self._active_task and not self._active_task.done():
            self._active_task.cancel()
            try:
                await self._active_task
            except (asyncio.CancelledError, Exception):
                pass
        self._active_task = None

    async def handle_offer(self, sdp: str, sdp_type: str) -> dict:
        """Process SDP offer from browser. Returns SDP answer."""
        offer = RTCSessionDescription(sdp=sdp, type=sdp_type)
        await self._pc.setRemoteDescription(offer)
        answer = await self._pc.createAnswer()
        await self._pc.setLocalDescription(answer)
        return {
            "sdp": self._pc.localDescription.sdp,
            "type": self._pc.localDescription.type,
        }

    async def add_ice_candidate(self, candidate: dict) -> None:
        """
        Forward browser ICE candidate to aiortc.

        Two input shapes are accepted:
          • the browser's ``RTCIceCandidate`` JSON, i.e. a ``"candidate"``
            SDP string plus ``sdpMid`` / ``sdpMLineIndex``;
          • a pre-parsed dict with ``ip``, ``port``, ``priority``, … keys.

        An empty ``"candidate"`` string (the browser's end-of-candidates
        marker) is ignored.
        """
        from aiortc import RTCIceCandidate

        sdp_line = candidate.get("candidate")
        if isinstance(sdp_line, str):
            sdp_line = sdp_line.strip()
            if not sdp_line:
                return
            from aiortc.sdp import candidate_from_sdp

            prefix = "candidate:"
            if sdp_line.startswith(prefix):
                sdp_line = sdp_line[len(prefix):]
            c = candidate_from_sdp(sdp_line)
            c.sdpMid = candidate.get("sdpMid")
            c.sdpMLineIndex = candidate.get("sdpMLineIndex")
        else:
            c = RTCIceCandidate(
                component=candidate.get("component", 1),
                foundation=candidate.get("foundation", ""),
                ip=candidate.get("ip", ""),
                port=candidate.get("port", 0),
                priority=candidate.get("priority", 0),
                protocol=candidate.get("protocol", "udp"),
                type=candidate.get("type", "host"),
                sdpMid=candidate.get("sdpMid"),
                sdpMLineIndex=candidate.get("sdpMLineIndex"),
            )
        await self._pc.addIceCandidate(c)

    async def close(self) -> None:
        """Stop all background work, cancel the active turn, close the peer."""
        cancelled = self._stop_background()
        if cancelled:
            await asyncio.gather(*cancelled, return_exceptions=True)
        await self._cancel_active()
        await self._pc.close()