| """ |
| services/webrtc_pipeline.py — WebRTC Audio Pipeline + Full Parallelization |
| |
| FIX-BUG3 (AudioFrameReceiver never driven): |
| In the original code, AudioFrameReceiver was instantiated but its recv() |
| method was never called. aiortc only delivers frames when a consumer calls |
| recv() in a loop. Without this, the frame queue was always empty → no audio |
| reached the VAD → no utterances → zero voice responses via WebRTC. |
| |
| Fix: spawn a coroutine (_recv_loop) that calls receiver.recv() continuously. |
| |
| All other logic preserved. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import json |
| import uuid |
| from typing import Optional |
|
|
| try: |
| from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack |
| from aiortc.contrib.media import MediaBlackhole |
| import av |
| AIORTC_AVAILABLE = True |
| except ImportError: |
| AIORTC_AVAILABLE = False |
| print("[WebRTC] aiortc not installed β WebRTC pipeline unavailable. " |
| "Install: pip install aiortc") |
|
|
| try: |
| import webrtcvad |
| VAD_AVAILABLE = True |
| except ImportError: |
| VAD_AVAILABLE = False |
| print("[WebRTC] webrtcvad not installed β VAD unavailable.") |
|
|
| from services.stt import STTProcessor |
| from services.streaming import ParallelTTSStreamer |
|
|
|
|
| |
| |
| |
|
|
class _VADSegmenter:
    """
    Split a stream of raw 16-bit mono PCM into utterances.

    Incoming bytes are re-framed into chunks of exactly ``frame_bytes`` and
    each chunk is classified by the VAD. Speech (plus the trailing silence
    that follows it) is accumulated; once ``silence_limit`` consecutive
    non-speech frames follow speech, the accumulated bytes are returned as
    one utterance.

    When webrtcvad is not installed (``self._vad is None``) the segmenter
    falls back to emitting fixed windows of ``_FALLBACK_WINDOW_SECONDS``.
    """

    # Length, in seconds, of the window emitted when no VAD is available.
    _FALLBACK_WINDOW_SECONDS = 3

    def __init__(
        self,
        sample_rate: int = 16_000,
        frame_ms: int = 20,
        aggressiveness: int = 2,
        silence_limit: int = 12,
    ) -> None:
        """
        Args:
            sample_rate: PCM sample rate in Hz.
            frame_ms: Duration of one VAD frame in milliseconds.
            aggressiveness: Value passed straight to ``webrtcvad.Vad``.
            silence_limit: Number of consecutive non-speech frames after
                speech that closes an utterance.
        """
        self.sample_rate = sample_rate
        # 2 bytes per sample (16-bit mono).
        self.frame_bytes = int(sample_rate * frame_ms / 1000) * 2
        self.silence_limit = silence_limit
        self._vad = webrtcvad.Vad(aggressiveness) if VAD_AVAILABLE else None
        self._buffer = bytearray()    # bytes of the utterance being built
        self._pending = bytearray()   # bytes not yet forming a full VAD frame
        self._silence_count = 0
        self._active = False

    def process_frame(self, pcm_frame: bytes) -> Optional[bytes]:
        """
        Feed PCM bytes (ideally one frame, but any length is accepted).

        Input is re-framed internally: bytes beyond a whole number of VAD
        frames are kept for the next call instead of being truncated, and
        short input is held back instead of being zero-padded.

        Returns:
            The bytes of a completed utterance when speech has ended during
            this call, otherwise ``None``. If more than one utterance
            completes within a single call, they are concatenated.
        """
        if self._vad is None:
            return self._accumulate_window(pcm_frame)

        size = self.frame_bytes
        if size <= 0:
            # Degenerate configuration: no frame can ever be formed.
            return None

        self._pending.extend(pcm_frame)
        completed = []
        while len(self._pending) >= size:
            frame = bytes(self._pending[:size])
            del self._pending[:size]
            utterance = self._classify_frame(frame)
            if utterance is not None:
                completed.append(utterance)

        return b"".join(completed) if completed else None

    def _accumulate_window(self, pcm_frame: bytes) -> Optional[bytes]:
        """No-VAD fallback: emit the buffer once it holds a full window."""
        self._buffer.extend(pcm_frame)
        window_bytes = self.sample_rate * self._FALLBACK_WINDOW_SECONDS * 2
        if len(self._buffer) < window_bytes:
            return None
        data = bytes(self._buffer)
        self._buffer.clear()
        return data

    def _classify_frame(self, frame: bytes) -> Optional[bytes]:
        """Run the VAD on one exact-size frame and update segmenter state."""
        try:
            is_speech = self._vad.is_speech(frame, self.sample_rate)
        except Exception:
            # Best effort: a frame the VAD rejects is treated as non-speech.
            is_speech = False

        if is_speech:
            self._buffer.extend(frame)
            self._active = True
            self._silence_count = 0
        elif self._active:
            # Keep trailing silence so the utterance is not clipped.
            self._buffer.extend(frame)
            self._silence_count += 1

        if self._active and self._silence_count >= self.silence_limit:
            data = bytes(self._buffer)
            self._buffer.clear()
            self._silence_count = 0
            self._active = False
            return data

        return None
|
|
|
|
| |
| |
| |
|
|
if AIORTC_AVAILABLE:
    class AudioFrameReceiver(MediaStreamTrack):
        """
        Adapter around an incoming WebRTC audio track.

        A background task pulls frames from the wrapped track, resamples
        them to 16 kHz mono s16 and offers the raw PCM bytes to an
        asyncio.Queue.

        Nothing is consumed until start_receiving() is called.
        """

        kind = "audio"

        def __init__(self, track: MediaStreamTrack, frame_queue: asyncio.Queue) -> None:
            super().__init__()
            self._track = track
            self._frame_queue = frame_queue
            # Created lazily on the first received frame.
            self._resampler: Optional[av.AudioResampler] = None
            self._recv_task: Optional[asyncio.Task] = None

        def start_receiving(self) -> None:
            """Start the background recv loop unless one is already running."""
            current = self._recv_task
            if current is not None and not current.done():
                return
            self._recv_task = asyncio.ensure_future(self._recv_loop())

        async def _recv_loop(self) -> None:
            """Pull frames from the wrapped track until recv() raises."""
            while True:
                try:
                    frame = await self._track.recv()
                except Exception as exc:
                    print(f"[WebRTC] AudioFrameReceiver: track ended ({exc})")
                    return

                if self._resampler is None:
                    self._resampler = av.AudioResampler(
                        format="s16",
                        layout="mono",
                        rate=16_000,
                    )

                try:
                    for resampled_frame in self._resampler.resample(frame):
                        self._offer_pcm(bytes(resampled_frame.planes[0]))
                except Exception as exc:
                    print(f"[WebRTC] Resample error: {exc}")

        def _offer_pcm(self, pcm: bytes) -> None:
            """Queue one PCM chunk; the chunk is dropped when the queue is full."""
            try:
                self._frame_queue.put_nowait(pcm)
            except asyncio.QueueFull:
                pass

        async def recv(self):
            """Required override: hand back the next frame of the wrapped track."""
            return await self._track.recv()

        def stop_receiving(self) -> None:
            """Cancel the background recv loop if it is still running."""
            current = self._recv_task
            if not current or current.done():
                return
            current.cancel()
|
|
|
|
| |
| |
| |
|
|
class _TurnPipeline:
    """
    Runs one conversation turn: speech bytes -> transcript -> LLM stream -> audio.

    Designed to be created fresh per turn (or cancelled on barge-in).
    Control messages are sent over the data channel as JSON strings; audio
    chunks are sent as-is.
    """

    def __init__(self, ai_backend, data_channel, on_stt=None, on_token=None):
        """
        Args:
            ai_backend: Object whose awaitable ``main(user_id, transcript)``
                returns an async iterator of tokens.
            data_channel: Channel used for control and audio output; nothing
                is sent unless its ``readyState`` is ``"open"``.
            on_stt: Optional callback invoked with the transcript.
            on_token: Optional callback invoked with each LLM token.
        """
        self._ai = ai_backend
        self._channel = data_channel
        self._on_stt = on_stt
        self._on_token = on_token
        self._stt = STTProcessor()
        self._streamer = ParallelTTSStreamer()
        self._cancelled = False
        # Stage tasks spawned by run(); cancel() cancels the unfinished ones.
        self._tasks: list[asyncio.Task] = []

    async def run(self, user_id: str, audio_bytes: bytes) -> None:
        """
        Full pipeline: audio -> STT -> LLM + TTS delivery (concurrent).

        Always finishes by sending ``{"type": "end"}`` unless the turn was
        cancelled after a transcript had been produced.
        """
        try:
            transcript = await self._stt.transcribe(audio_bytes)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            # Treat an STT failure like an empty transcript so the client
            # still receives the "end" marker for this turn.
            print(f"[Pipeline] STT error: {exc}")
            transcript = None

        if not transcript or self._cancelled:
            self._send_ctrl({"type": "end"})
            return

        if self._on_stt:
            self._on_stt(transcript)
        self._send_ctrl({"type": "stt", "text": transcript})

        # Register the stage tasks so cancel() can actually reach them.
        stages = [
            asyncio.ensure_future(self._run_llm(user_id, transcript)),
            asyncio.ensure_future(self._run_tts_delivery()),
        ]
        self._tasks.extend(stages)

        outcomes = await asyncio.gather(*stages, return_exceptions=True)
        for outcome in outcomes:
            failed = isinstance(outcome, Exception)
            if failed and not isinstance(outcome, asyncio.CancelledError):
                print(f"[Pipeline] Stage error: {outcome}")

        if not self._cancelled:
            self._send_ctrl({"type": "end"})

    async def _run_llm(self, user_id: str, transcript: str) -> None:
        """
        Stream LLM tokens to the client, then hand the text to the TTS streamer.

        Each token is forwarded as an ``llm_token`` control message. The
        accumulated text is passed to the streamer once, after the token
        stream has ended, and the streamer is always flushed.
        """
        pieces: list[str] = []
        try:
            stream = await self._ai.main(user_id, transcript)
            async for token in stream:
                # An empty token ends the stream, as does cancellation.
                if self._cancelled or not token:
                    break
                pieces.append(token)
                if self._on_token:
                    self._on_token(token)
                self._send_ctrl({"type": "llm_token", "token": token})
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            print(f"[Pipeline] LLM error: {exc}")
        finally:
            full_text = "".join(pieces)
            if full_text and not self._cancelled:
                await self._streamer.add_token(full_text)
            await self._streamer.flush()

    async def _run_tts_delivery(self) -> None:
        """Forward audio chunks from the TTS streamer to the data channel."""
        async for chunk in self._streamer.stream_audio():
            if self._cancelled:
                break
            self._send_audio(chunk)

    def _send(self, payload) -> None:
        """Best-effort send: skipped unless the channel is open; errors are logged."""
        channel = self._channel
        if not channel or channel.readyState != "open":
            return
        try:
            channel.send(payload)
        except Exception as exc:
            print(f"[Pipeline] DataChannel send failed: {exc}")

    def _send_ctrl(self, payload: dict) -> None:
        """Send a control message as a JSON string."""
        self._send(json.dumps(payload))

    def _send_audio(self, data: bytes) -> None:
        """Send one audio chunk unchanged."""
        self._send(data)

    async def cancel(self) -> None:
        """Mark the turn cancelled, cancel the streamer and any running stages."""
        self._cancelled = True
        await self._streamer.cancel()
        running = [task for task in self._tasks if not task.done()]
        for task in running:
            task.cancel()
        if running:
            await asyncio.gather(*running, return_exceptions=True)
|
|
|
|
| |
| |
| |
|
|
class WebRTCSession:
    """
    Manages one WebRTC peer connection:
      - handles SDP negotiation and ICE candidates
      - receives the audio track -> VAD -> dispatches conversation turns
      - sends responses back via the RTCDataChannel
      - supports barge-in (the active turn is cancelled when a new one starts)
    """

    def __init__(self, ai_backend) -> None:
        """
        Args:
            ai_backend: Passed unchanged to every ``_TurnPipeline``.

        Raises:
            RuntimeError: If aiortc could not be imported.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is required for WebRTC mode")
        self._ai = ai_backend
        self.user_id = f"rtc_{uuid.uuid4().hex[:12]}"
        self._pc = RTCPeerConnection()
        self._channel = None
        self._frame_q: asyncio.Queue = asyncio.Queue(maxsize=500)
        self._vad = _VADSegmenter()
        self._active_turn: Optional[_TurnPipeline] = None
        self._active_task: Optional[asyncio.Task] = None
        self._receivers: list[AudioFrameReceiver] = []
        # Single consumer of _frame_q; kept so close() can cancel it.
        self._processor_task: Optional[asyncio.Task] = None
        self._setup_pc()

    def _setup_pc(self) -> None:
        """Register the track, datachannel and connection-state handlers."""
        pc = self._pc

        @pc.on("track")
        def on_track(track):
            if track.kind == "audio":
                receiver = AudioFrameReceiver(track, self._frame_q)
                receiver.start_receiving()
                self._receivers.append(receiver)
                self._ensure_frame_processor()
                print(f"[WebRTC] Audio track received β receiver started β")

        @pc.on("datachannel")
        def on_datachannel(channel):
            self._channel = channel
            print(f"[WebRTC] DataChannel open: {channel.label}")

            @channel.on("message")
            def on_message(msg):
                try:
                    data = json.loads(msg)
                    if data.get("type") == "cancel":
                        asyncio.ensure_future(self._cancel_active())
                    elif data.get("type") == "init" and data.get("user_id"):
                        # Client-supplied id: coerce to str and cap the length.
                        self.user_id = str(data["user_id"])[:64]
                except Exception:
                    # Messages that are not valid JSON objects are ignored.
                    pass

        @pc.on("connectionstatechange")
        async def on_state():
            print(f"[WebRTC] Connection state: {pc.connectionState}")
            if pc.connectionState in ("failed", "closed"):
                await self._cancel_active()

    def _ensure_frame_processor(self) -> None:
        """Start the frame-processor task unless one is already running."""
        task = self._processor_task
        if task is None or task.done():
            # Keep the reference: close() cancels it, and every audio track
            # shares this one consumer and its VAD state.
            self._processor_task = asyncio.ensure_future(self._frame_processor())

    async def _frame_processor(self) -> None:
        """Consume PCM frames from the queue -> VAD -> dispatch turns."""
        while True:
            try:
                frame = await asyncio.wait_for(self._frame_q.get(), timeout=5.0)
            except asyncio.TimeoutError:
                continue
            except Exception:
                break

            utterance = self._vad.process_frame(frame)
            if utterance:
                await self._dispatch_turn(utterance)

    async def _dispatch_turn(self, audio_bytes: bytes) -> None:
        """Barge-in aware: cancel the current turn, then start a new one."""
        await self._cancel_active()

        pipeline = _TurnPipeline(
            ai_backend=self._ai,
            data_channel=self._channel,
        )
        self._active_turn = pipeline
        self._active_task = asyncio.create_task(
            pipeline.run(self.user_id, audio_bytes)
        )

    async def _cancel_active(self) -> None:
        """Cancel the active turn and wait for its task to finish."""
        if self._active_turn:
            await self._active_turn.cancel()
            self._active_turn = None
        if self._active_task and not self._active_task.done():
            self._active_task.cancel()
            try:
                await self._active_task
            except (asyncio.CancelledError, Exception):
                pass
        self._active_task = None

    async def handle_offer(self, sdp: str, sdp_type: str) -> dict:
        """Apply the remote SDP offer and return the local answer as a dict."""
        offer = RTCSessionDescription(sdp=sdp, type=sdp_type)
        await self._pc.setRemoteDescription(offer)
        answer = await self._pc.createAnswer()
        await self._pc.setLocalDescription(answer)
        return {
            "sdp": self._pc.localDescription.sdp,
            "type": self._pc.localDescription.type,
        }

    async def add_ice_candidate(self, candidate: dict) -> None:
        """
        Forward a remote ICE candidate to aiortc.

        Two input shapes are accepted:
          - ``{"candidate": "candidate:...", "sdpMid": ..., "sdpMLineIndex": ...}``
            where the candidate line is parsed from its SDP form;
          - a dict of already-split fields (``component``, ``foundation``,
            ``ip``, ``port``, ``priority``, ``protocol``, ``type``).

        A missing/empty dict or an empty candidate string is ignored.
        """
        if not candidate:
            return

        from aiortc import RTCIceCandidate

        raw = candidate.get("candidate")
        if isinstance(raw, str):
            line = raw.strip()
            if not line:
                # Empty candidate string: nothing to add.
                return
            from aiortc.sdp import candidate_from_sdp

            # candidate_from_sdp expects the text after the "candidate:" prefix.
            if line.startswith("candidate:"):
                line = line.split(":", 1)[1]
            parsed = candidate_from_sdp(line)
            parsed.sdpMid = candidate.get("sdpMid")
            parsed.sdpMLineIndex = candidate.get("sdpMLineIndex")
            await self._pc.addIceCandidate(parsed)
            return

        c = RTCIceCandidate(
            component=candidate.get("component", 1),
            foundation=candidate.get("foundation", ""),
            ip=candidate.get("ip", ""),
            port=candidate.get("port", 0),
            priority=candidate.get("priority", 0),
            protocol=candidate.get("protocol", "udp"),
            type=candidate.get("type", "host"),
            sdpMid=candidate.get("sdpMid"),
            sdpMLineIndex=candidate.get("sdpMLineIndex"),
        )
        await self._pc.addIceCandidate(c)

    async def close(self) -> None:
        """Stop receivers and the frame processor, cancel the turn, close the PC."""
        for receiver in self._receivers:
            receiver.stop_receiving()
        self._receivers.clear()

        # Stop the queue consumer first so no new turn starts during shutdown.
        processor = self._processor_task
        self._processor_task = None
        if processor is not None and not processor.done():
            processor.cancel()
            try:
                await processor
            except (asyncio.CancelledError, Exception):
                pass

        await self._cancel_active()
        await self._pc.close()
|
|