"""Speech-to-Speech pipeline for audio-in, audio-out generation. This pipeline combines ASR (speech-to-text) with TTS (text-to-speech) to create a unified speech-to-speech interface that can be used with HuggingFace's pipeline API. Usage: from transformers import pipeline # Load as speech-to-speech pipeline pipe = pipeline("speech-to-speech", model="mazesmazes/tiny-audio-omni", trust_remote_code=True) # Process audio (outputs 48kHz by default for browser compatibility) result = pipe("audio.wav") # Returns: {"text": "transcription", "audio": np.array, "sampling_rate": 48000} # With custom TTS voice result = pipe("audio.wav", tts_voice="af_bella") # Output at native TTS rate (24kHz) without resampling result = pipe("audio.wav", output_sample_rate=24000) # Get only audio output (for streaming/playback) audio, sr = result["audio"], result["sampling_rate"] # Streaming with built-in VAD (Voice Activity Detection) for result in pipe.stream(audio_chunk_generator()): print(result["text"]) play_audio(result["audio"], result["sampling_rate"]) """ from collections.abc import Generator, Iterator from dataclasses import dataclass, field from pathlib import Path from typing import Any import numpy as np import scipy.signal import torch from transformers import Pipeline from transformers.pipelines.audio_utils import ffmpeg_read try: from .asr_modeling import ASRModel from .asr_pipeline import _truncate_repetitions, strip_thinking except ImportError: from asr_modeling import ASRModel # type: ignore[no-redef] from asr_pipeline import _truncate_repetitions, strip_thinking # type: ignore[no-redef] __all__ = ["SpeechToSpeechPipeline", "VADConfig"] # Default TTS settings DEFAULT_TTS_VOICE = "af_heart" TTS_SAMPLE_RATE = 24000 # Native Kokoro TTS sample rate DEFAULT_OUTPUT_SAMPLE_RATE = 48000 # Browser-friendly sample rate # Default VAD settings DEFAULT_VAD_THRESHOLD = 0.5 DEFAULT_SILENCE_DURATION_MS = 700 DEFAULT_INPUT_SAMPLE_RATE = 16000 @dataclass class VADConfig: """Configuration for Voice Activity Detection. Args: threshold: VAD probability threshold (0.0-1.0). Higher = stricter. silence_duration_ms: Milliseconds of silence before end-of-speech. sample_rate: Expected input audio sample rate. """ threshold: float = DEFAULT_VAD_THRESHOLD silence_duration_ms: int = DEFAULT_SILENCE_DURATION_MS sample_rate: int = DEFAULT_INPUT_SAMPLE_RATE @dataclass class _VADState: """Internal state for VAD streaming.""" is_speaking: bool = False silence_frames: int = 0 audio_buffer: list[np.ndarray] = field(default_factory=list) def reset(self): """Reset state after processing an utterance.""" self.is_speaking = False self.silence_frames = 0 self.audio_buffer = [] class SpeechToSpeechPipeline(Pipeline): """HuggingFace pipeline for speech-to-speech generation. This pipeline takes audio input, transcribes it using an ASR model, and synthesizes the response as speech using Kokoro TTS. Args: model: ASRModel instance for transcription tts_voice: Default Kokoro TTS voice ID (default: "af_heart") output_sample_rate: Output audio sample rate (default: 48000 for browser compatibility) **kwargs: Additional arguments passed to Pipeline base class Example: >>> from transformers import pipeline >>> pipe = pipeline("speech-to-speech", model="mazesmazes/tiny-audio-omni", trust_remote_code=True) >>> result = pipe("audio.wav") >>> result["text"] # Transcription/response text >>> result["audio"] # Audio as numpy array (48kHz) >>> result["sampling_rate"] # 48000 """ model: ASRModel def __init__( self, model: ASRModel, tts_voice: str = DEFAULT_TTS_VOICE, output_sample_rate: int = DEFAULT_OUTPUT_SAMPLE_RATE, vad_config: VADConfig | None = None, **kwargs, ): """Initialize Speech-to-Speech pipeline.""" feature_extractor = kwargs.pop("feature_extractor", None) tokenizer = kwargs.pop("tokenizer", model.tokenizer) if feature_extractor is None: feature_extractor = model.get_processor().feature_extractor super().__init__( model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs, ) self.tts_voice = tts_voice self.output_sample_rate = output_sample_rate self.vad_config = vad_config or VADConfig() self._tts_pipeline = None self._vad_model = None self._vad_utils = None @property def tts_pipeline(self): """Lazy-load Kokoro TTS pipeline on first use.""" if self._tts_pipeline is None: try: from kokoro import KPipeline self._tts_pipeline = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M") except ImportError as e: raise ImportError( "Kokoro TTS is required for speech-to-speech. " "Install with: pip install kokoro>=0.9.2\n" "Also requires espeak-ng: apt-get install espeak-ng" ) from e return self._tts_pipeline @property def vad_model(self): """Lazy-load Silero VAD model on first use.""" if self._vad_model is None: self._vad_model, self._vad_utils = torch.hub.load( repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False, ) return self._vad_model @property def vad_utils(self): """Get VAD utilities (loads model if needed).""" if self._vad_utils is None: # Access vad_model to trigger loading _ = self.vad_model return self._vad_utils def stream( self, audio_chunks: Iterator[np.ndarray], tts_voice: str | None = None, output_sample_rate: int | None = None, vad_config: VADConfig | None = None, ) -> Generator[dict[str, Any], None, None]: """Process streaming audio with VAD and yield responses. Takes an iterator of audio chunks, detects speech using Silero VAD, and yields responses when speech ends (after silence threshold). Args: audio_chunks: Iterator yielding audio chunks as numpy arrays (float32, 16kHz). Each chunk should be ~100-500ms of audio. tts_voice: Kokoro voice ID for TTS output (default: self.tts_voice) output_sample_rate: Output sample rate (default: self.output_sample_rate) vad_config: VAD configuration (default: self.vad_config) Yields: Dict with 'text', 'audio', and 'sampling_rate' for each detected utterance. Example: >>> def audio_generator(): ... while True: ... chunk = get_audio_chunk() # Get ~100ms of audio ... if chunk is None: ... break ... yield chunk >>> for result in pipe.stream(audio_generator()): ... print(result["text"]) ... play_audio(result["audio"], result["sampling_rate"]) """ config = vad_config or self.vad_config voice = tts_voice or self.tts_voice target_sr = output_sample_rate or self.output_sample_rate state = _VADState() vad_utils = self.vad_utils if vad_utils is None: raise RuntimeError("Failed to load Silero VAD model") get_speech_timestamps = vad_utils[0] # Calculate silence threshold in frames # Assuming ~100ms chunks at 16kHz = 1600 samples per chunk # silence_duration_ms / chunk_duration_ms = number of silent chunks chunk_duration_ms = 100 # Approximate, will be calculated per chunk silence_threshold = max(1, config.silence_duration_ms // chunk_duration_ms) for chunk in audio_chunks: # Ensure chunk is float32 if chunk.dtype != np.float32: chunk = chunk.astype(np.float32) # Normalize if needed (int16 range to float32) if chunk.max() > 1.0 or chunk.min() < -1.0: chunk = chunk / 32768.0 # Update chunk duration estimate for silence threshold chunk_duration_ms = len(chunk) / config.sample_rate * 1000 silence_threshold = max(1, int(config.silence_duration_ms / chunk_duration_ms)) # Run VAD speech_timestamps = get_speech_timestamps( torch.from_numpy(chunk), self.vad_model, sampling_rate=config.sample_rate, threshold=config.threshold, ) has_speech = len(speech_timestamps) > 0 if has_speech: if not state.is_speaking: state.is_speaking = True state.audio_buffer = [] state.audio_buffer.append(chunk) state.silence_frames = 0 elif state.is_speaking: state.audio_buffer.append(chunk) state.silence_frames += 1 if state.silence_frames >= silence_threshold: # End of speech detected - process the utterance if state.audio_buffer: full_audio = np.concatenate(state.audio_buffer) result = self( {"array": full_audio, "sampling_rate": config.sample_rate}, tts_voice=voice, output_sample_rate=target_sr, ) yield result state.reset() def _sanitize_parameters( self, tts_voice: str | None = None, output_sample_rate: int | None = None, return_text_only: bool = False, user_prompt: str | None = None, **kwargs, ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: """Sanitize and route parameters to preprocessing, forward, and postprocessing.""" preprocess_kwargs: dict[str, Any] = {} forward_kwargs: dict[str, Any] = {} postprocess_kwargs: dict[str, Any] = {} if tts_voice is not None: postprocess_kwargs["tts_voice"] = tts_voice if output_sample_rate is not None: postprocess_kwargs["output_sample_rate"] = output_sample_rate if return_text_only: postprocess_kwargs["return_text_only"] = return_text_only if user_prompt is not None: forward_kwargs["user_prompt"] = user_prompt return preprocess_kwargs, forward_kwargs, postprocess_kwargs def preprocess(self, inputs, **kwargs) -> dict[str, Any]: """Preprocess audio inputs for the model. Handles various input formats: - File path (str) - Dict with 'array' and 'sampling_rate' - Dict with 'raw' audio bytes - Raw numpy array - Bytes Returns: Dict with input_features and attention_mask for the model """ # Extract audio array from various formats audio_array = self._extract_audio(inputs) if audio_array is None: raise ValueError(f"Could not extract audio from input type: {type(inputs)}") # Use feature extractor to get mel features processed = self.feature_extractor( audio_array, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt", return_attention_mask=True, ) return { "input_features": processed.input_features, "attention_mask": processed.attention_mask, } def _forward(self, model_inputs: dict, user_prompt: str | None = None) -> dict[str, Any]: """Run ASR model to generate text from audio. Args: model_inputs: Dict with input_features and attention_mask user_prompt: Optional custom prompt for the model Returns: Dict with generated token IDs """ input_features = model_inputs["input_features"].to(self.model.device) attention_mask = model_inputs["attention_mask"].to(self.model.device) # Set custom prompt if provided original_prompt = None if user_prompt: original_prompt = self.model.TRANSCRIBE_PROMPT self.model.TRANSCRIBE_PROMPT = user_prompt try: generated_ids = self.model.generate( input_features=input_features, audio_attention_mask=attention_mask, ) finally: if original_prompt is not None: self.model.TRANSCRIBE_PROMPT = original_prompt return {"tokens": generated_ids} def postprocess( self, model_outputs: dict, tts_voice: str | None = None, output_sample_rate: int | None = None, return_text_only: bool = False, ) -> dict[str, Any]: """Convert model output to text and synthesize speech. Args: model_outputs: Dict with 'tokens' containing generated IDs tts_voice: Kokoro voice ID (default: self.tts_voice) output_sample_rate: Output sample rate (default: self.output_sample_rate) return_text_only: If True, skip TTS and return only text Returns: Dict with 'text', 'audio' (numpy array), and 'sampling_rate' """ target_sr = output_sample_rate or self.output_sample_rate tokens = model_outputs.get("tokens") if tokens is None: return { "text": "", "audio": np.array([], dtype=np.float32), "sampling_rate": target_sr, } # Convert tokens to text if torch.is_tensor(tokens): tokens = tokens.cpu() if tokens.dim() > 1: tokens = tokens[0] # Filter EOS tokens if hasattr(self.model, "generation_config") and self.model.generation_config is not None: eos_ids = self.model.generation_config.eos_token_id if eos_ids is not None: eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids} tokens = [t for t in tokens.tolist() if t not in eos_set] text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip() text = strip_thinking(text) text = _truncate_repetitions(text) result = {"text": text} # Synthesize speech unless text-only requested if not return_text_only: voice = tts_voice or self.tts_voice audio = self._synthesize_speech(text, voice) # Resample if target sample rate differs from native TTS rate audio = self._resample_audio(audio, TTS_SAMPLE_RATE, target_sr) result["audio"] = audio result["sampling_rate"] = target_sr return result def _synthesize_speech(self, text: str, voice: str) -> np.ndarray: """Synthesize speech from text using Kokoro TTS. Args: text: Text to synthesize voice: Kokoro voice ID Returns: Audio as numpy array (float32, 24kHz native TTS rate) """ if not text or not text.strip(): return np.array([], dtype=np.float32) try: audio_chunks = [] for _, _, audio in self.tts_pipeline(text, voice=voice): audio_chunks.append(audio) if audio_chunks: return np.concatenate(audio_chunks) except Exception: pass return np.array([], dtype=np.float32) def _resample_audio(self, audio: np.ndarray, from_sr: int, to_sr: int) -> np.ndarray: """Resample audio to target sample rate. Args: audio: Input audio array from_sr: Source sample rate to_sr: Target sample rate Returns: Resampled audio array """ if len(audio) == 0 or from_sr == to_sr: return audio num_samples = int(len(audio) * to_sr / from_sr) return scipy.signal.resample(audio, num_samples).astype(np.float32) def text_to_speech( self, text: str, voice: str | None = None, output_sample_rate: int | None = None, ) -> dict[str, Any]: """Convert text to speech using Kokoro TTS. This is a convenience method for generating audio from text without going through the full speech-to-speech pipeline. Args: text: Text to synthesize voice: Kokoro voice ID (default: self.tts_voice) output_sample_rate: Output sample rate (default: self.output_sample_rate) Returns: Dict with 'audio' (numpy array) and 'sampling_rate' keys """ voice = voice or self.tts_voice target_sr = output_sample_rate or self.output_sample_rate audio = self._synthesize_speech(text, voice) audio = self._resample_audio(audio, TTS_SAMPLE_RATE, target_sr) return {"audio": audio, "sampling_rate": target_sr} def _extract_audio(self, inputs) -> np.ndarray | None: """Extract audio array from various input formats. Args: inputs: Audio input in various formats Returns: Audio as numpy array (float32) or None if extraction fails """ if isinstance(inputs, dict): if "array" in inputs: audio = inputs["array"] if isinstance(audio, np.ndarray): return audio.astype(np.float32) if audio.dtype != np.float32 else audio return np.array(audio, dtype=np.float32) if "raw" in inputs: audio = inputs["raw"] if isinstance(audio, np.ndarray): return audio.astype(np.float32) if audio.dtype != np.float32 else audio return np.array(audio, dtype=np.float32) elif isinstance(inputs, str): # File path with Path(inputs).open("rb") as f: return ffmpeg_read(f.read(), sampling_rate=16000) elif isinstance(inputs, bytes): return ffmpeg_read(inputs, sampling_rate=16000) elif isinstance(inputs, np.ndarray): return inputs.astype(np.float32) if inputs.dtype != np.float32 else inputs return None def __call__(self, inputs, **kwargs) -> dict[str, Any]: """Process audio input and return speech output. Args: inputs: Audio input (file path, dict with array, numpy array, or bytes) tts_voice: Kokoro voice ID for TTS output (default: "af_heart") return_text_only: If True, skip TTS and return only transcription user_prompt: Custom prompt for the model Returns: Dict with: - 'text': Transcription/response text - 'audio': Synthesized speech as numpy array (float32) - 'sampling_rate': Audio sample rate (24000) """ return super().__call__(inputs, **kwargs)