Voice-AI-Agent / services /webrtc_pipeline.py
rakib72642's picture
checkpoint 2
77a79ae
"""
services/webrtc_pipeline.py — WebRTC Audio Pipeline + Full Parallelization
FIX-BUG3 (AudioFrameReceiver never driven):
In the original code, AudioFrameReceiver was instantiated but its recv()
method was never called. aiortc only delivers frames when a consumer calls
recv() in a loop. Without this, the frame queue was always empty → no audio
reached the VAD → no utterances → zero voice responses via WebRTC.
Fix: spawn a coroutine (_recv_loop) that calls receiver.recv() continuously.
All other logic preserved.
"""
from __future__ import annotations
import asyncio
import json
import uuid
from typing import Optional
try:
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
from aiortc.contrib.media import MediaBlackhole
import av
AIORTC_AVAILABLE = True
except ImportError:
AIORTC_AVAILABLE = False
print("[WebRTC] aiortc not installed β€” WebRTC pipeline unavailable. "
"Install: pip install aiortc")
try:
import webrtcvad
VAD_AVAILABLE = True
except ImportError:
VAD_AVAILABLE = False
print("[WebRTC] webrtcvad not installed β€” VAD unavailable.")
from services.stt import STTProcessor
from services.streaming import ParallelTTSStreamer
# ══════════════════════════════════════════════════════════════════════════════
# VAD SEGMENTER (PCM frames → speech utterances)
# ══════════════════════════════════════════════════════════════════════════════
class _VADSegmenter:
    """
    Accumulates raw 16-bit mono PCM and cuts it into speech utterances.

    Audio is fed in through process_frame(). An utterance starts at the first
    frame the VAD classifies as speech and ends once ``silence_limit``
    consecutive non-speech frames have followed it; the trailing silence is
    kept as part of the returned audio.

    When webrtcvad is unavailable there is no speech detection at all: audio
    is simply buffered and handed back in ~3 second blocks.
    """

    def __init__(
        self,
        sample_rate: int = 16_000,
        frame_ms: int = 20,        # 20ms frames (file notes this as the aiortc default)
        aggressiveness: int = 2,
        silence_limit: int = 12,   # 12 frames x 20ms = ~240ms silence ends an utterance
    ) -> None:
        """
        Args:
            sample_rate: PCM sample rate in Hz.
            frame_ms: Duration of one VAD analysis frame in milliseconds.
            aggressiveness: Value passed straight to ``webrtcvad.Vad``.
            silence_limit: Number of consecutive non-speech frames that ends
                an utterance.
        """
        self.sample_rate = sample_rate
        # Samples per frame x 2 bytes per 16-bit sample.
        self.frame_bytes = int(sample_rate * frame_ms / 1000) * 2
        self.silence_limit = silence_limit
        self._vad = webrtcvad.Vad(aggressiveness) if VAD_AVAILABLE else None
        self._buffer = bytearray()    # audio of the utterance being collected
        self._pending = bytearray()   # input bytes not yet forming a whole frame
        self._silence_count = 0
        self._active = False

    def process_frame(self, pcm_frame: bytes) -> Optional[bytes]:
        """
        Feed PCM audio (normally one 20ms frame, but any length is accepted).

        Input is re-chunked into exact ``frame_bytes`` frames before being
        classified, so oversized chunks are not truncated and undersized ones
        are not padded with artificial silence; leftover bytes wait for the
        next call.

        Returns:
            The complete utterance as bytes when speech has just ended,
            otherwise None. At most one utterance is returned per call; any
            frames left over are processed on the following call.
        """
        if self._vad is None:
            # No VAD: buffer everything, flush after 3s of audio.
            self._buffer.extend(pcm_frame)
            if len(self._buffer) >= self.sample_rate * 3 * 2:
                data = bytes(self._buffer)
                self._buffer.clear()
                return data
            return None

        size = self.frame_bytes
        if size <= 0:
            # Degenerate configuration: no frame can ever be formed.
            return None

        self._pending.extend(pcm_frame)
        while len(self._pending) >= size:
            frame = bytes(self._pending[:size])
            del self._pending[:size]
            utterance = self._classify(frame)
            if utterance is not None:
                return utterance
        return None

    def _classify(self, frame: bytes) -> Optional[bytes]:
        """Run the VAD on one exact-size frame and advance the utterance state."""
        try:
            is_speech = self._vad.is_speech(frame, self.sample_rate)
        except Exception:
            # A frame the VAD rejects is treated as silence rather than fatal.
            is_speech = False

        if is_speech:
            self._buffer.extend(frame)
            self._active = True
            self._silence_count = 0
        elif self._active:
            # Silence inside an utterance is kept so the audio stays contiguous.
            self._buffer.extend(frame)
            self._silence_count += 1

        if not (self._active and self._silence_count >= self.silence_limit):
            return None

        data = bytes(self._buffer)
        self._buffer.clear()
        self._silence_count = 0
        self._active = False
        return data
# ══════════════════════════════════════════════════════════════════════════════
# AUDIO TRACK RECEIVER
# ══════════════════════════════════════════════════════════════════════════════
if AIORTC_AVAILABLE:

    class AudioFrameReceiver(MediaStreamTrack):
        """
        Wraps an incoming WebRTC audio track.

        Frames pulled from the wrapped track are resampled to 16kHz mono
        signed 16-bit PCM and pushed, as bytes, into an asyncio.Queue.

        IMPORTANT: call start_receiving() after construction to begin
        consuming frames from the underlying track via recv().
        """

        kind = "audio"

        def __init__(self, track: MediaStreamTrack, frame_queue: asyncio.Queue) -> None:
            """
            Args:
                track: The remote audio track to read from.
                frame_queue: Queue that receives one bytes object per
                    resampled frame.
            """
            super().__init__()
            self._track = track
            self._frame_queue = frame_queue
            # Created lazily on the first frame.
            self._resampler: Optional[av.AudioResampler] = None
            self._recv_task: Optional[asyncio.Task] = None

        def start_receiving(self) -> None:
            """
            FIX-BUG3: Spawn the recv() loop so the track actually delivers frames.
            Without this, _frame_queue stays empty forever.

            Calling it again while the loop is still running is a no-op.
            """
            if self._recv_task is None or self._recv_task.done():
                self._recv_task = asyncio.ensure_future(self._recv_loop())

        async def _recv_loop(self) -> None:
            """Continuously consume frames from the remote track until it fails or ends."""
            while True:
                try:
                    frame = await self._track.recv()
                except Exception as exc:
                    print(f"[WebRTC] AudioFrameReceiver: track ended ({exc})")
                    break

                if self._resampler is None:
                    self._resampler = av.AudioResampler(
                        format="s16",
                        layout="mono",
                        rate=16_000,
                    )

                try:
                    resampled = self._resampler.resample(frame)
                    # NOTE(review): some PyAV releases return a single frame
                    # (or None while the resampler is still buffering) instead
                    # of a list; normalise so both shapes work.
                    if resampled is None:
                        continue
                    if not isinstance(resampled, (list, tuple)):
                        resampled = [resampled]

                    for rf in resampled:
                        # s16 mono = 2 bytes per sample. The plane buffer can be
                        # larger than the audio it holds (allocation padding),
                        # so keep only the real payload; this is a no-op when
                        # the buffer is already exact.
                        pcm = bytes(rf.planes[0])[: rf.samples * 2]
                        try:
                            self._frame_queue.put_nowait(pcm)
                        except asyncio.QueueFull:
                            pass  # Drop frame under backpressure; prefer real-time
                except Exception as exc:
                    print(f"[WebRTC] Resample error: {exc}")

        async def recv(self):
            """Required override; delegates to the underlying track."""
            return await self._track.recv()

        def stop_receiving(self) -> None:
            """Cancel the background recv() loop if it is still running."""
            if self._recv_task and not self._recv_task.done():
                self._recv_task.cancel()
# ══════════════════════════════════════════════════════════════════════════════
# TURN PIPELINE (STT → LLM → TTS, all parallel)
# ══════════════════════════════════════════════════════════════════════════════
class _TurnPipeline:
    """
    Runs one conversation turn: speech bytes -> transcript -> LLM stream -> audio.

    Designed to be created fresh per turn (or cancelled on barge-in).

    Messages sent on the data channel:
      * JSON text ``{"type": "stt", "text": ...}`` once the transcript is known
      * JSON text ``{"type": "llm_token", "token": ...}`` per LLM token
      * binary audio chunks produced by the TTS streamer
      * JSON text ``{"type": "end"}`` when the turn is over
    """

    def __init__(self, ai_backend, data_channel, on_stt=None, on_token=None):
        """
        Args:
            ai_backend: Object whose awaitable ``main(user_id, transcript)``
                returns an async iterator of tokens.
            data_channel: RTCDataChannel used for control and audio delivery.
            on_stt: Optional callback invoked with the transcript.
            on_token: Optional callback invoked with each LLM token.
        """
        self._ai = ai_backend
        self._channel = data_channel
        self._on_stt = on_stt
        self._on_token = on_token
        self._stt = STTProcessor()
        self._streamer = ParallelTTSStreamer()
        self._cancelled = False
        self._tasks: list[asyncio.Task] = []

    async def run(self, user_id: str, audio_bytes: bytes) -> None:
        """
        Full pipeline: audio -> STT -> LLM + TTS delivery (run concurrently).

        A failure in any stage is logged and the turn is still terminated with
        an "end" message (unless it was cancelled), so the client is never
        left waiting on a turn that died.
        """
        # Phase 1: STT
        try:
            transcript = await self._stt.transcribe(audio_bytes)
        except Exception as exc:
            # Treat a failed transcription like an empty one.
            print(f"[Pipeline] STT error: {exc}")
            transcript = None

        if not transcript or self._cancelled:
            self._send_ctrl({"type": "end"})
            return

        if self._on_stt:
            self._on_stt(transcript)
        self._send_ctrl({"type": "stt", "text": transcript})

        # Phase 2: LLM + TTS delivery in parallel
        stages = ("LLM", "TTS delivery")
        results = await asyncio.gather(
            self._run_llm(user_id, transcript),
            self._run_tts_delivery(),
            return_exceptions=True,
        )
        for stage, result in zip(stages, results):
            # Cancellation is expected on barge-in; only report real errors.
            if isinstance(result, Exception) and not isinstance(
                result, asyncio.CancelledError
            ):
                print(f"[Pipeline] {stage} failed: {result!r}")

        if not self._cancelled:
            self._send_ctrl({"type": "end"})

    async def _run_llm(self, user_id: str, transcript: str) -> None:
        """
        Forward LLM tokens to the client, then hand the full reply to TTS.

        Tokens go to the data channel as they arrive; the TTS streamer gets
        the whole response in one piece at the end and is always flushed.
        """
        pieces: list[str] = []
        try:
            stream = await self._ai.main(user_id, transcript)
            async for token in stream:
                # A falsy token ends the stream, as does cancellation.
                if self._cancelled or not token:
                    break
                pieces.append(token)
                if self._on_token:
                    self._on_token(token)
                self._send_ctrl({"type": "llm_token", "token": token})
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            print(f"[Pipeline] LLM error: {exc}")
        finally:
            # Feed the completed response to TTS for more reliable synthesis.
            full_text = "".join(pieces)
            if full_text and not self._cancelled:
                await self._streamer.add_token(full_text)
            await self._streamer.flush()

    async def _run_tts_delivery(self) -> None:
        """Forward audio chunks from the TTS streamer to the data channel."""
        async for chunk in self._streamer.stream_audio():
            if self._cancelled:
                break
            self._send_audio(chunk)

    def _send(self, message) -> None:
        """Best-effort send: skipped when the channel is missing or not open."""
        if self._channel and self._channel.readyState == "open":
            try:
                self._channel.send(message)
            except Exception:
                # The channel may close between the check and the send;
                # dropping the message is the intended behaviour.
                pass

    def _send_ctrl(self, payload: dict) -> None:
        """Send a control message as JSON text."""
        if self._channel and self._channel.readyState == "open":
            try:
                encoded = json.dumps(payload)
            except Exception:
                return
            self._send(encoded)

    def _send_audio(self, data: bytes) -> None:
        """Send one binary audio chunk."""
        self._send(data)

    async def cancel(self) -> None:
        """Mark the turn cancelled, stop the TTS streamer and cancel tracked tasks."""
        self._cancelled = True
        await self._streamer.cancel()
        for t in self._tasks:
            t.cancel()
        if self._tasks:
            await asyncio.gather(*self._tasks, return_exceptions=True)
# ══════════════════════════════════════════════════════════════════════════════
# WEBRTC SESSION HANDLER
# ══════════════════════════════════════════════════════════════════════════════
class WebRTCSession:
    """
    Manages one WebRTC peer connection:
      - Handles ICE/SDP negotiation
      - Receives audio track -> VAD -> one turn pipeline per utterance
      - Sends responses back via RTCDataChannel
      - Supports barge-in (cancel active turn on new speech)
    """

    def __init__(self, ai_backend) -> None:
        """
        Args:
            ai_backend: Backend handed to every turn pipeline.

        Raises:
            RuntimeError: If aiortc is not installed.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is required for WebRTC mode")
        self._ai = ai_backend
        # Provisional id; the client may replace it with an "init" message.
        self.user_id = f"rtc_{uuid.uuid4().hex[:12]}"
        self._pc = RTCPeerConnection()
        self._channel = None
        self._frame_q: asyncio.Queue = asyncio.Queue(maxsize=500)
        self._vad = _VADSegmenter()
        self._active_turn: Optional[_TurnPipeline] = None
        self._active_task: Optional[asyncio.Task] = None
        # Keep references to receivers so they are not garbage-collected
        self._receivers: list[AudioFrameReceiver] = []
        # Single consumer of _frame_q; kept so it can be cancelled in close().
        self._processor_task: Optional[asyncio.Task] = None
        self._setup_pc()

    def _setup_pc(self) -> None:
        """Register the peer-connection event handlers."""
        pc = self._pc

        @pc.on("track")
        def on_track(track):
            if track.kind == "audio":
                # FIX-BUG3: create receiver AND start its recv() loop
                receiver = AudioFrameReceiver(track, self._frame_q)
                receiver.start_receiving()
                self._receivers.append(receiver)  # prevent GC
                self._ensure_frame_processor()
                print(f"[WebRTC] Audio track received β€” receiver started βœ“")

        @pc.on("datachannel")
        def on_datachannel(channel):
            self._channel = channel
            print(f"[WebRTC] DataChannel open: {channel.label}")

            @channel.on("message")
            def on_message(msg):
                # Best-effort: anything that is not a recognised JSON control
                # message is ignored.
                try:
                    data = json.loads(msg)
                    if data.get("type") == "cancel":
                        asyncio.ensure_future(self._cancel_active())
                    elif data.get("type") == "init" and data.get("user_id"):
                        # Cap the client-supplied id at 64 characters.
                        self.user_id = str(data["user_id"])[:64]
                except Exception:
                    pass

        @pc.on("connectionstatechange")
        async def on_state():
            print(f"[WebRTC] Connection state: {pc.connectionState}")
            if pc.connectionState in ("failed", "closed"):
                await self._cancel_active()

    def _ensure_frame_processor(self) -> None:
        """
        Start the frame-processing task unless one is already running.

        Exactly one consumer may drain the queue: the VAD segmenter is
        stateful, so a second consumer (e.g. one per extra audio track) would
        interleave frames and corrupt utterance boundaries.
        """
        task = self._processor_task
        if task is None or task.done():
            self._processor_task = asyncio.ensure_future(self._frame_processor())

    async def _frame_processor(self) -> None:
        """Consume PCM frames from queue -> VAD -> dispatch turns (until cancelled)."""
        while True:
            frame = await self._frame_q.get()
            try:
                utterance = self._vad.process_frame(frame)
                if utterance:
                    await self._dispatch_turn(utterance)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                # One bad utterance must not end voice handling for the session.
                print(f"[WebRTC] Frame processing error: {exc}")

    async def _dispatch_turn(self, audio_bytes: bytes) -> None:
        """Barge-in aware: cancel current turn, start new one."""
        await self._cancel_active()
        pipeline = _TurnPipeline(
            ai_backend=self._ai,
            data_channel=self._channel,
        )
        self._active_turn = pipeline
        self._active_task = asyncio.create_task(
            pipeline.run(self.user_id, audio_bytes)
        )

    async def _cancel_active(self) -> None:
        """Cancel the running turn (if any) and wait for its task to finish."""
        if self._active_turn:
            await self._active_turn.cancel()
            self._active_turn = None
        if self._active_task and not self._active_task.done():
            self._active_task.cancel()
            try:
                await self._active_task
            except (asyncio.CancelledError, Exception):
                pass
        self._active_task = None

    async def handle_offer(self, sdp: str, sdp_type: str) -> dict:
        """Process SDP offer from browser. Returns the SDP answer as {"sdp", "type"}."""
        offer = RTCSessionDescription(sdp=sdp, type=sdp_type)
        await self._pc.setRemoteDescription(offer)
        answer = await self._pc.createAnswer()
        await self._pc.setLocalDescription(answer)
        return {
            "sdp": self._pc.localDescription.sdp,
            "type": self._pc.localDescription.type,
        }

    async def add_ice_candidate(self, candidate: dict) -> None:
        """
        Forward a browser ICE candidate to aiortc.

        Two input shapes are accepted:
          * browser style: ``{"candidate": "candidate:...", "sdpMid": ...,
            "sdpMLineIndex": ...}``; the candidate line is parsed. An empty
            candidate string (end-of-candidates) is ignored.
          * pre-parsed fields (``component``, ``foundation``, ``ip``, ``port``,
            ``priority``, ``protocol``, ``type``), as originally supported.

        An empty or missing dict is ignored.
        """
        if not candidate:
            return

        from aiortc import RTCIceCandidate

        sdp_line = candidate.get("candidate")
        if isinstance(sdp_line, str):
            from aiortc.sdp import candidate_from_sdp

            sdp_line = sdp_line.strip()
            if not sdp_line:
                return
            # candidate_from_sdp expects the text after the "candidate:" prefix.
            _, prefix, tail = sdp_line.partition("candidate:")
            parsed = candidate_from_sdp(tail if prefix else sdp_line)
            parsed.sdpMid = candidate.get("sdpMid")
            parsed.sdpMLineIndex = candidate.get("sdpMLineIndex")
            await self._pc.addIceCandidate(parsed)
            return

        c = RTCIceCandidate(
            component=candidate.get("component", 1),
            foundation=candidate.get("foundation", ""),
            ip=candidate.get("ip", ""),
            port=candidate.get("port", 0),
            priority=candidate.get("priority", 0),
            protocol=candidate.get("protocol", "udp"),
            type=candidate.get("type", "host"),
            sdpMid=candidate.get("sdpMid"),
            sdpMLineIndex=candidate.get("sdpMLineIndex"),
        )
        await self._pc.addIceCandidate(c)

    async def close(self) -> None:
        """Stop receivers and the frame processor, cancel the active turn, close the peer."""
        for receiver in self._receivers:
            receiver.stop_receiving()
        self._receivers.clear()

        task, self._processor_task = self._processor_task, None
        if task is not None and not task.done():
            task.cancel()
            # Wait for the cancellation to land without re-raising it here.
            await asyncio.gather(task, return_exceptions=True)

        await self._cancel_active()
        await self._pc.close()