| """Speech-to-Speech pipeline for audio-in, audio-out generation. |
| |
| This pipeline combines ASR (speech-to-text) with TTS (text-to-speech) to create |
| a unified speech-to-speech interface that can be used with HuggingFace's pipeline API. |
| |
| Usage: |
| from transformers import pipeline |
| |
| # Load as speech-to-speech pipeline |
| pipe = pipeline("speech-to-speech", model="mazesmazes/tiny-audio-omni", trust_remote_code=True) |
| |
| # Process audio (outputs 48kHz by default for browser compatibility) |
| result = pipe("audio.wav") |
| # Returns: {"text": "transcription", "audio": np.array, "sampling_rate": 48000} |
| |
| # With custom TTS voice |
| result = pipe("audio.wav", tts_voice="af_bella") |
| |
| # Output at native TTS rate (24kHz) without resampling |
| result = pipe("audio.wav", output_sample_rate=24000) |
| |
| # Get only audio output (for streaming/playback) |
| audio, sr = result["audio"], result["sampling_rate"] |
| |
| # Streaming with built-in VAD (Voice Activity Detection) |
| for result in pipe.stream(audio_chunk_generator()): |
| print(result["text"]) |
| play_audio(result["audio"], result["sampling_rate"]) |
| """ |
|
|
| from collections.abc import Generator, Iterator |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import scipy.signal |
| import torch |
| from transformers import Pipeline |
| from transformers.pipelines.audio_utils import ffmpeg_read |
|
|
| try: |
| from .asr_modeling import ASRModel |
| from .asr_pipeline import _truncate_repetitions, strip_thinking |
| except ImportError: |
| from asr_modeling import ASRModel |
| from asr_pipeline import _truncate_repetitions, strip_thinking |
|
|
| __all__ = ["SpeechToSpeechPipeline", "VADConfig"] |
|
|
| |
# Default Kokoro voice id used when the caller does not pass `tts_voice`.
DEFAULT_TTS_VOICE = "af_heart"
# Kokoro synthesizes at 24 kHz natively; output is resampled from this rate.
TTS_SAMPLE_RATE = 24000
# Default output rate; 48 kHz is chosen for browser/WebAudio compatibility.
DEFAULT_OUTPUT_SAMPLE_RATE = 48000

# Voice-activity-detection defaults used by `stream()` / `VADConfig`.
DEFAULT_VAD_THRESHOLD = 0.5  # Silero speech-probability threshold (0.0-1.0).
DEFAULT_SILENCE_DURATION_MS = 700  # Trailing silence that ends an utterance.
DEFAULT_INPUT_SAMPLE_RATE = 16000  # Expected input (microphone) sample rate.
|
|
|
|
@dataclass
class VADConfig:
    """Configuration for Voice Activity Detection.

    Args:
        threshold: VAD probability threshold (0.0-1.0). Higher = stricter.
        silence_duration_ms: Milliseconds of silence before end-of-speech.
        sample_rate: Expected input audio sample rate.
    """

    # Silero speech-probability cutoff; chunks scoring below it count as silence.
    threshold: float = DEFAULT_VAD_THRESHOLD
    # How much trailing silence (ms) marks the end of an utterance.
    silence_duration_ms: int = DEFAULT_SILENCE_DURATION_MS
    # Sample rate the incoming chunks are assumed to use (Hz).
    sample_rate: int = DEFAULT_INPUT_SAMPLE_RATE
|
|
|
|
@dataclass
class _VADState:
    """Internal mutable state for VAD-driven streaming (see `SpeechToSpeechPipeline.stream`)."""

    # True while we are inside a detected utterance.
    is_speaking: bool = False
    # Count of consecutive speech-free chunks observed since speech was last heard.
    silence_frames: int = 0
    # Audio chunks buffered for the current utterance (concatenated on flush).
    audio_buffer: list[np.ndarray] = field(default_factory=list)

    def reset(self):
        """Reset state after processing an utterance."""
        self.is_speaking = False
        self.silence_frames = 0
        self.audio_buffer = []
|
|
|
|
class SpeechToSpeechPipeline(Pipeline):
    """HuggingFace pipeline for speech-to-speech generation.

    This pipeline takes audio input, transcribes it using an ASR model,
    and synthesizes the response as speech using Kokoro TTS.

    Args:
        model: ASRModel instance for transcription
        tts_voice: Default Kokoro TTS voice ID (default: "af_heart")
        output_sample_rate: Output audio sample rate (default: 48000 for browser compatibility)
        **kwargs: Additional arguments passed to Pipeline base class

    Example:
        >>> from transformers import pipeline
        >>> pipe = pipeline("speech-to-speech", model="mazesmazes/tiny-audio-omni", trust_remote_code=True)
        >>> result = pipe("audio.wav")
        >>> result["text"]  # Transcription/response text
        >>> result["audio"]  # Audio as numpy array (48kHz)
        >>> result["sampling_rate"]  # 48000
    """

    # Narrowed type of the inherited `model` attribute set by Pipeline.__init__.
    model: ASRModel

    def __init__(
        self,
        model: ASRModel,
        tts_voice: str = DEFAULT_TTS_VOICE,
        output_sample_rate: int = DEFAULT_OUTPUT_SAMPLE_RATE,
        vad_config: VADConfig | None = None,
        **kwargs,
    ):
        """Initialize Speech-to-Speech pipeline."""
        # Pop pipeline components from kwargs so they are not passed twice to super().
        feature_extractor = kwargs.pop("feature_extractor", None)
        tokenizer = kwargs.pop("tokenizer", model.tokenizer)

        # Fall back to the feature extractor bundled with the ASR model's processor.
        if feature_extractor is None:
            feature_extractor = model.get_processor().feature_extractor

        super().__init__(
            model=model,
            feature_extractor=feature_extractor,
            tokenizer=tokenizer,
            **kwargs,
        )
        self.tts_voice = tts_voice
        self.output_sample_rate = output_sample_rate
        self.vad_config = vad_config or VADConfig()
        # TTS and VAD are heavyweight; loaded lazily on first property access.
        self._tts_pipeline = None
        self._vad_model = None
        self._vad_utils = None

    @property
    def tts_pipeline(self):
        """Lazy-load Kokoro TTS pipeline on first use.

        Raises:
            ImportError: If the `kokoro` package is not installed.
        """
        if self._tts_pipeline is None:
            try:
                from kokoro import KPipeline

                # lang_code="a" selects American English in Kokoro.
                self._tts_pipeline = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M")
            except ImportError as e:
                raise ImportError(
                    "Kokoro TTS is required for speech-to-speech. "
                    "Install with: pip install kokoro>=0.9.2\n"
                    "Also requires espeak-ng: apt-get install espeak-ng"
                ) from e
        return self._tts_pipeline

    @property
    def vad_model(self):
        """Lazy-load Silero VAD model on first use.

        Downloads (or reuses a cached copy of) the snakers4/silero-vad hub
        entry, which returns a ``(model, utils)`` pair; both are cached.
        """
        if self._vad_model is None:
            self._vad_model, self._vad_utils = torch.hub.load(
                repo_or_dir="snakers4/silero-vad",
                model="silero_vad",
                force_reload=False,
            )
        return self._vad_model

    @property
    def vad_utils(self):
        """Get VAD utilities (loads model if needed)."""
        if self._vad_utils is None:
            # Accessing vad_model populates _vad_utils as a side effect.
            _ = self.vad_model
        return self._vad_utils

    def stream(
        self,
        audio_chunks: Iterator[np.ndarray],
        tts_voice: str | None = None,
        output_sample_rate: int | None = None,
        vad_config: VADConfig | None = None,
    ) -> Generator[dict[str, Any], None, None]:
        """Process streaming audio with VAD and yield responses.

        Takes an iterator of audio chunks, detects speech using Silero VAD,
        and yields responses when speech ends (after silence threshold).

        Args:
            audio_chunks: Iterator yielding audio chunks as numpy arrays (float32, 16kHz).
                Each chunk should be ~100-500ms of audio.
            tts_voice: Kokoro voice ID for TTS output (default: self.tts_voice)
            output_sample_rate: Output sample rate (default: self.output_sample_rate)
            vad_config: VAD configuration (default: self.vad_config)

        Yields:
            Dict with 'text', 'audio', and 'sampling_rate' for each detected utterance.

        Example:
            >>> def audio_generator():
            ...     while True:
            ...         chunk = get_audio_chunk()  # Get ~100ms of audio
            ...         if chunk is None:
            ...             break
            ...         yield chunk
            >>> for result in pipe.stream(audio_generator()):
            ...     print(result["text"])
            ...     play_audio(result["audio"], result["sampling_rate"])
        """
        config = vad_config or self.vad_config
        voice = tts_voice or self.tts_voice
        target_sr = output_sample_rate or self.output_sample_rate

        state = _VADState()
        vad_utils = self.vad_utils
        if vad_utils is None:
            raise RuntimeError("Failed to load Silero VAD model")
        # NOTE(review): assumes Silero's utils tuple puts get_speech_timestamps
        # at index 0 (current hubconf contract) — verify on hub upgrades.
        get_speech_timestamps = vad_utils[0]

        # Initial estimate; recomputed from each chunk's real length below.
        chunk_duration_ms = 100
        silence_threshold = max(1, config.silence_duration_ms // chunk_duration_ms)

        for chunk in audio_chunks:
            # VAD and downstream processing expect float32.
            if chunk.dtype != np.float32:
                chunk = chunk.astype(np.float32)

            # Heuristic: values outside [-1, 1] are assumed to be int16-range
            # PCM and are rescaled. NOTE(review): an empty chunk would make
            # .max() raise ValueError — confirm producers never yield empty arrays.
            if chunk.max() > 1.0 or chunk.min() < -1.0:
                chunk = chunk / 32768.0

            # Number of consecutive silent chunks that ends the utterance,
            # derived from this chunk's actual duration.
            chunk_duration_ms = len(chunk) / config.sample_rate * 1000
            silence_threshold = max(1, int(config.silence_duration_ms / chunk_duration_ms))

            # Any detected speech segment in the chunk counts as "speaking".
            speech_timestamps = get_speech_timestamps(
                torch.from_numpy(chunk),
                self.vad_model,
                sampling_rate=config.sample_rate,
                threshold=config.threshold,
            )
            has_speech = len(speech_timestamps) > 0

            if has_speech:
                if not state.is_speaking:
                    # Speech onset: start a fresh utterance buffer.
                    state.is_speaking = True
                    state.audio_buffer = []
                state.audio_buffer.append(chunk)
                state.silence_frames = 0
            elif state.is_speaking:
                # Silence during an utterance: keep buffering (preserves natural
                # trailing audio) and count toward the end-of-speech threshold.
                state.audio_buffer.append(chunk)
                state.silence_frames += 1

                if state.silence_frames >= silence_threshold:
                    # End of utterance: run the full pipeline on the buffered audio.
                    if state.audio_buffer:
                        full_audio = np.concatenate(state.audio_buffer)
                        result = self(
                            {"array": full_audio, "sampling_rate": config.sample_rate},
                            tts_voice=voice,
                            output_sample_rate=target_sr,
                        )
                        yield result

                    state.reset()

    def _sanitize_parameters(
        self,
        tts_voice: str | None = None,
        output_sample_rate: int | None = None,
        return_text_only: bool = False,
        user_prompt: str | None = None,
        **kwargs,
    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
        """Sanitize and route parameters to preprocessing, forward, and postprocessing.

        Part of the HF Pipeline contract: returns three kwarg dicts consumed by
        preprocess(), _forward(), and postprocess() respectively.
        """
        preprocess_kwargs: dict[str, Any] = {}
        forward_kwargs: dict[str, Any] = {}
        postprocess_kwargs: dict[str, Any] = {}

        if tts_voice is not None:
            postprocess_kwargs["tts_voice"] = tts_voice
        if output_sample_rate is not None:
            postprocess_kwargs["output_sample_rate"] = output_sample_rate
        if return_text_only:
            postprocess_kwargs["return_text_only"] = return_text_only
        if user_prompt is not None:
            forward_kwargs["user_prompt"] = user_prompt

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs, **kwargs) -> dict[str, Any]:
        """Preprocess audio inputs for the model.

        Handles various input formats:
        - File path (str)
        - Dict with 'array' and 'sampling_rate'
        - Dict with 'raw' audio bytes
        - Raw numpy array
        - Bytes

        Returns:
            Dict with input_features and attention_mask for the model

        Raises:
            ValueError: If the input type is not one of the supported formats.
        """
        audio_array = self._extract_audio(inputs)

        if audio_array is None:
            raise ValueError(f"Could not extract audio from input type: {type(inputs)}")

        # NOTE(review): the array is fed to the feature extractor at its own
        # expected rate; a 'sampling_rate' supplied in a dict input is not used
        # to resample here — confirm callers pass audio at the extractor's rate.
        processed = self.feature_extractor(
            audio_array,
            sampling_rate=self.feature_extractor.sampling_rate,
            return_tensors="pt",
            return_attention_mask=True,
        )

        return {
            "input_features": processed.input_features,
            "attention_mask": processed.attention_mask,
        }

    def _forward(self, model_inputs: dict, user_prompt: str | None = None) -> dict[str, Any]:
        """Run ASR model to generate text from audio.

        Args:
            model_inputs: Dict with input_features and attention_mask
            user_prompt: Optional custom prompt for the model

        Returns:
            Dict with generated token IDs
        """
        input_features = model_inputs["input_features"].to(self.model.device)
        attention_mask = model_inputs["attention_mask"].to(self.model.device)

        # Temporarily swap in the caller's prompt; restored in `finally`.
        # NOTE(review): mutating the model attribute is not safe under
        # concurrent calls sharing this model instance.
        original_prompt = None
        if user_prompt:
            original_prompt = self.model.TRANSCRIBE_PROMPT
            self.model.TRANSCRIBE_PROMPT = user_prompt

        try:
            generated_ids = self.model.generate(
                input_features=input_features,
                audio_attention_mask=attention_mask,
            )
        finally:
            if original_prompt is not None:
                self.model.TRANSCRIBE_PROMPT = original_prompt

        return {"tokens": generated_ids}

    def postprocess(
        self,
        model_outputs: dict,
        tts_voice: str | None = None,
        output_sample_rate: int | None = None,
        return_text_only: bool = False,
    ) -> dict[str, Any]:
        """Convert model output to text and synthesize speech.

        Args:
            model_outputs: Dict with 'tokens' containing generated IDs
            tts_voice: Kokoro voice ID (default: self.tts_voice)
            output_sample_rate: Output sample rate (default: self.output_sample_rate)
            return_text_only: If True, skip TTS and return only text
                (the returned dict then has no 'audio'/'sampling_rate' keys)

        Returns:
            Dict with 'text', 'audio' (numpy array), and 'sampling_rate'
        """
        target_sr = output_sample_rate or self.output_sample_rate
        tokens = model_outputs.get("tokens")

        # No generation output: return an empty result rather than failing.
        if tokens is None:
            return {
                "text": "",
                "audio": np.array([], dtype=np.float32),
                "sampling_rate": target_sr,
            }

        # Move to CPU and take the first sequence of a batched output.
        if torch.is_tensor(tokens):
            tokens = tokens.cpu()
            if tokens.dim() > 1:
                tokens = tokens[0]

        # Strip EOS token id(s) anywhere in the sequence before decoding.
        # NOTE(review): .tolist() assumes tokens is still a tensor here — holds
        # when _forward returns a tensor; confirm for other producers.
        if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
            eos_ids = self.model.generation_config.eos_token_id
            if eos_ids is not None:
                eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
                tokens = [t for t in tokens.tolist() if t not in eos_set]

        text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
        # Clean up model chatter: drop <think> blocks and runaway repetitions.
        text = strip_thinking(text)
        text = _truncate_repetitions(text)

        result = {"text": text}

        # Synthesize speech unless the caller only wants the transcription.
        if not return_text_only:
            voice = tts_voice or self.tts_voice
            audio = self._synthesize_speech(text, voice)
            # Kokoro outputs 24 kHz; resample to the requested output rate.
            audio = self._resample_audio(audio, TTS_SAMPLE_RATE, target_sr)
            result["audio"] = audio
            result["sampling_rate"] = target_sr

        return result

    def _synthesize_speech(self, text: str, voice: str) -> np.ndarray:
        """Synthesize speech from text using Kokoro TTS.

        Args:
            text: Text to synthesize
            voice: Kokoro voice ID

        Returns:
            Audio as numpy array (float32, 24kHz native TTS rate);
            empty array for empty text or on synthesis failure.
        """
        if not text or not text.strip():
            return np.array([], dtype=np.float32)

        try:
            # Kokoro yields (graphemes, phonemes, audio) per text segment;
            # concatenate the audio segments into one utterance.
            audio_chunks = []
            for _, _, audio in self.tts_pipeline(text, voice=voice):
                audio_chunks.append(audio)

            if audio_chunks:
                return np.concatenate(audio_chunks)
        except Exception:
            # Deliberate best-effort: TTS failure degrades to silent output
            # rather than failing the whole pipeline call.
            pass

        return np.array([], dtype=np.float32)

    def _resample_audio(self, audio: np.ndarray, from_sr: int, to_sr: int) -> np.ndarray:
        """Resample audio to target sample rate.

        Args:
            audio: Input audio array
            from_sr: Source sample rate
            to_sr: Target sample rate

        Returns:
            Resampled audio array (float32); input returned unchanged when
            empty or already at the target rate.
        """
        if len(audio) == 0 or from_sr == to_sr:
            return audio

        # FFT-based resampling via scipy; output length scales with the rate ratio.
        num_samples = int(len(audio) * to_sr / from_sr)
        return scipy.signal.resample(audio, num_samples).astype(np.float32)

    def text_to_speech(
        self,
        text: str,
        voice: str | None = None,
        output_sample_rate: int | None = None,
    ) -> dict[str, Any]:
        """Convert text to speech using Kokoro TTS.

        This is a convenience method for generating audio from text without
        going through the full speech-to-speech pipeline.

        Args:
            text: Text to synthesize
            voice: Kokoro voice ID (default: self.tts_voice)
            output_sample_rate: Output sample rate (default: self.output_sample_rate)

        Returns:
            Dict with 'audio' (numpy array) and 'sampling_rate' keys
        """
        voice = voice or self.tts_voice
        target_sr = output_sample_rate or self.output_sample_rate
        audio = self._synthesize_speech(text, voice)
        audio = self._resample_audio(audio, TTS_SAMPLE_RATE, target_sr)
        return {"audio": audio, "sampling_rate": target_sr}

    def _extract_audio(self, inputs) -> np.ndarray | None:
        """Extract audio array from various input formats.

        Args:
            inputs: Audio input in various formats

        Returns:
            Audio as numpy array (float32) or None if extraction fails
        """
        if isinstance(inputs, dict):
            # Datasets-style dict: {'array': ..., 'sampling_rate': ...}.
            if "array" in inputs:
                audio = inputs["array"]
                if isinstance(audio, np.ndarray):
                    return audio.astype(np.float32) if audio.dtype != np.float32 else audio
                return np.array(audio, dtype=np.float32)
            # Microphone-style dict: {'raw': ...}.
            if "raw" in inputs:
                audio = inputs["raw"]
                if isinstance(audio, np.ndarray):
                    return audio.astype(np.float32) if audio.dtype != np.float32 else audio
                return np.array(audio, dtype=np.float32)

        elif isinstance(inputs, str):
            # File path: decode via ffmpeg. NOTE(review): decode rate is
            # hard-coded to 16 kHz here rather than taken from the feature
            # extractor — confirm these always agree.
            with Path(inputs).open("rb") as f:
                return ffmpeg_read(f.read(), sampling_rate=16000)

        elif isinstance(inputs, bytes):
            return ffmpeg_read(inputs, sampling_rate=16000)

        elif isinstance(inputs, np.ndarray):
            return inputs.astype(np.float32) if inputs.dtype != np.float32 else inputs

        # Unsupported input type; caller raises a descriptive error.
        return None

    def __call__(self, inputs, **kwargs) -> dict[str, Any]:
        """Process audio input and return speech output.

        Args:
            inputs: Audio input (file path, dict with array, numpy array, or bytes)
            tts_voice: Kokoro voice ID for TTS output (default: "af_heart")
            return_text_only: If True, skip TTS and return only transcription
            user_prompt: Custom prompt for the model

        Returns:
            Dict with:
                - 'text': Transcription/response text
                - 'audio': Synthesized speech as numpy array (float32)
                - 'sampling_rate': Audio sample rate (self.output_sample_rate; default 48000)
        """
        return super().__call__(inputs, **kwargs)
|
|