| """ASR pipeline for audio-to-text transcription with optional timestamps and diarization.""" |
|
|
| import re |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| import transformers |
|
|
# Support both package-relative imports (when imported as part of a package)
# and absolute imports (when this file is run directly as a script).
try:
    from .alignment import ForcedAligner
    from .asr_modeling import ASRModel
    from .diarization import SpeakerDiarizer
except ImportError:
    from alignment import ForcedAligner
    from asr_modeling import ASRModel
    from diarization import SpeakerDiarizer


# Public API of this module; re-exports the alignment/diarization helpers.
__all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline"]
|
|
|
|
class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
    """ASR Pipeline for audio-to-text transcription.

    Extends the Hugging Face ``AutomaticSpeechRecognitionPipeline`` with
    optional word-level timestamps (via ``ForcedAligner``) and speaker
    diarization (via ``SpeakerDiarizer``).
    """

    model: ASRModel

    def __init__(self, model: ASRModel, **kwargs):
        """Initialize ASR pipeline.

        Args:
            model: ASRModel instance for transcription
            **kwargs: Additional arguments (feature_extractor, tokenizer, device)
        """
        feature_extractor = kwargs.pop("feature_extractor", None)
        tokenizer = kwargs.pop("tokenizer", model.tokenizer)

        if feature_extractor is None:
            feature_extractor = model.get_processor().feature_extractor

        super().__init__(
            model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
        )
        # Scratch slot holding the raw waveform for the duration of one
        # __call__, so alignment/diarization can reuse it after decoding.
        # NOTE(review): this makes the pipeline instance non-reentrant —
        # concurrent calls on one instance would clobber each other.
        self._current_audio = None

    def _sanitize_parameters(self, **kwargs):
        """Intercept our custom parameters before parent class validates them.

        ``__call__`` pops these itself before delegating; this is a safety net
        in case the parent pipeline machinery forwards them here anyway.
        """
        for key in (
            "return_timestamps",
            "return_speakers",
            "num_speakers",
            "min_speakers",
            "max_speakers",
            "hf_token",
            "user_prompt",
            "diarization_backend",
        ):
            kwargs.pop(key, None)

        return super()._sanitize_parameters(**kwargs)

    def __call__(
        self,
        inputs,
        **kwargs,
    ):
        """Transcribe audio with optional word-level timestamps and speaker diarization.

        Args:
            inputs: Audio input (file path, dict with array/sampling_rate, etc.)
            return_timestamps: If True, return word-level timestamps using forced alignment
            return_speakers: If True, return speaker labels for each word
            user_prompt: Custom transcription prompt (default: "Transcribe: ")
            num_speakers: Exact number of speakers (if known, for diarization)
            min_speakers: Minimum number of speakers (for diarization)
            max_speakers: Maximum number of speakers (for diarization)
            **kwargs: Additional arguments passed to the pipeline

        Returns:
            Dict with 'text' key, 'words' key if return_timestamps=True,
            and speaker labels on words if return_speakers=True
        """
        return_timestamps = kwargs.pop("return_timestamps", False)
        return_speakers = kwargs.pop("return_speakers", False)
        user_prompt = kwargs.pop("user_prompt", None)
        diarization_params = {
            "num_speakers": kwargs.pop("num_speakers", None),
            "min_speakers": kwargs.pop("min_speakers", None),
            "max_speakers": kwargs.pop("max_speakers", None),
        }

        # Speaker assignment needs word timings, so it implies timestamps.
        if return_speakers:
            return_timestamps = True

        # Temporarily swap the model's transcription prompt when a custom one
        # is supplied; restored in the finally block below.
        original_prompt = None
        if user_prompt:
            original_prompt = self.model.TRANSCRIBE_PROMPT
            self.model.TRANSCRIBE_PROMPT = user_prompt

        try:
            # Keep the raw waveform around; alignment/diarization need it
            # after the text has been decoded.
            if return_timestamps or return_speakers:
                self._current_audio = self._extract_audio(inputs)

            result = super().__call__(inputs, **kwargs)

            if return_timestamps and self._current_audio is not None:
                text = result.get("text", "")
                if text:
                    try:
                        words = ForcedAligner.align(
                            self._current_audio["array"],
                            text,
                            sample_rate=self._current_audio.get("sampling_rate", 16000),
                        )
                        result["words"] = words
                    except Exception as e:
                        # Best-effort: report the failure instead of crashing
                        # the whole transcription.
                        result["words"] = []
                        result["timestamp_error"] = str(e)
                else:
                    result["words"] = []

            if return_speakers and self._current_audio is not None:
                try:
                    speaker_segments = SpeakerDiarizer.diarize(
                        self._current_audio["array"],
                        sample_rate=self._current_audio.get("sampling_rate", 16000),
                        **{k: v for k, v in diarization_params.items() if v is not None},
                    )
                    result["speaker_segments"] = speaker_segments

                    # Attach a speaker label to each aligned word, when we
                    # have word timings.
                    if result.get("words"):
                        result["words"] = SpeakerDiarizer.assign_speakers_to_words(
                            result["words"],
                            speaker_segments,
                        )
                except Exception as e:
                    # Best-effort: report the failure instead of crashing
                    # the whole transcription.
                    result["speaker_segments"] = []
                    result["diarization_error"] = str(e)

            return result
        finally:
            # Always restore shared state — the original code leaked both the
            # cached audio and the swapped prompt when any step raised.
            self._current_audio = None
            if original_prompt is not None:
                self.model.TRANSCRIBE_PROMPT = original_prompt

    def _extract_audio(self, inputs) -> dict | None:
        """Extract audio array from various input formats using HF utilities.

        Returns:
            Dict with "array" and "sampling_rate" keys, or None when the
            input format is not recognized.
        """
        from transformers.pipelines.audio_utils import ffmpeg_read

        if isinstance(inputs, dict):
            if "array" in inputs:
                return {
                    "array": inputs["array"],
                    "sampling_rate": inputs.get("sampling_rate", 16000),
                }
            if "raw" in inputs:
                return {
                    "array": inputs["raw"],
                    "sampling_rate": inputs.get("sampling_rate", 16000),
                }
        elif isinstance(inputs, (str, Path)):
            # Accept pathlib.Path as well as plain string paths; decode the
            # file via ffmpeg to 16 kHz mono.
            with Path(inputs).open("rb") as f:
                audio = ffmpeg_read(f.read(), sampling_rate=16000)
            return {"array": audio, "sampling_rate": 16000}
        elif isinstance(inputs, bytes):
            audio = ffmpeg_read(inputs, sampling_rate=16000)
            return {"array": audio, "sampling_rate": 16000}
        elif isinstance(inputs, np.ndarray):
            return {"array": inputs, "sampling_rate": 16000}

        return None

    def preprocess(self, inputs, **preprocess_params):
        """Preprocess audio inputs for the model.

        Args:
            inputs: Audio input (dict with array, file path, etc.)
            **preprocess_params: Additional preprocessing parameters

        Yields:
            Model input dicts with input_features and attention_mask
        """
        # Normalize {"array": ...} inputs to the {"raw": ...} shape the
        # parent pipeline expects.
        if isinstance(inputs, dict) and "array" in inputs:
            inputs = {
                "raw": inputs["array"],
                "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
            }

        for item in super().preprocess(inputs, **preprocess_params):
            # The parent may omit the chunk marker; default to single chunk.
            if "is_last" not in item:
                item["is_last"] = True
            yield item

    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
        """Run model forward pass to generate transcription.

        Args:
            model_inputs: Dict with input_features and attention_mask
            **generate_kwargs: Generation parameters

        Returns:
            Dict with generated token IDs
        """
        is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True

        input_features = model_inputs["input_features"].to(self.model.device)
        audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)

        generated_ids = self.model.generate(
            input_features=input_features,
            audio_attention_mask=audio_attention_mask,
            **generate_kwargs,
        )

        return {"tokens": generated_ids, "is_last": is_last}

    def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
        """Convert model output tokens to text.

        Args:
            model_outputs: Dict with 'tokens' key containing generated IDs
            **kwargs: Additional postprocessing parameters

        Returns:
            Dict with 'text' key containing transcription
        """
        if isinstance(model_outputs, list):
            model_outputs = model_outputs[0] if model_outputs else {}

        tokens = model_outputs.get("tokens")
        if tokens is None:
            # Output did not come from our _forward; defer to the parent.
            return super().postprocess(model_outputs, **kwargs)

        if torch.is_tensor(tokens):
            tokens = tokens.cpu()
            if tokens.dim() > 1:
                # Batched output: decode only the first sequence.
                tokens = tokens[0]

        # Strip EOS ids explicitly; skip_special_tokens below does not always
        # cover custom eos ids declared in the generation config.
        if hasattr(self, "model") and hasattr(self.model, "generation_config"):
            eos_ids = self.model.generation_config.eos_token_id
            if eos_ids is not None:
                eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
                # Robustness fix: tokens may already be a plain sequence here,
                # in which case the original .tolist() call would raise.
                token_list = tokens.tolist() if torch.is_tensor(tokens) else list(tokens)
                tokens = [t for t in token_list if t not in eos_set]

        text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
        # Drop any chain-of-thought block the model emits before the transcript.
        text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
        # Guard against degenerate repetition loops at the end of generation.
        text = _truncate_repetitions(text)
        return {"text": text}
|
|
|
|
| def _truncate_repetitions(text: str, min_repeats: int = 3) -> str: |
| """Truncate repeated words/phrases/characters at end of text. |
| |
| Detects patterns like: |
| - Repeated words: "the the the the" -> "the" |
| - Repeated phrases: "i am sorry i am sorry i am sorry" -> "i am sorry" |
| - Repeated characters: "444444" -> "4" |
| |
| Args: |
| text: Input text to process |
| min_repeats: Minimum repetitions to trigger truncation (default 3) |
| |
| Returns: |
| Text with trailing repetitions removed |
| """ |
| if not text: |
| return text |
|
|
| |
| char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$") |
| text = char_pattern.sub(r"\1", text) |
|
|
| |
| word_pattern = re.compile( |
| r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE |
| ) |
| while word_pattern.search(text): |
| text = word_pattern.sub(r"\1", text) |
|
|
| |
| |
| words = text.split() |
| if len(words) >= min_repeats * 2: |
| |
| for phrase_len in range(2, min(21, len(words) // min_repeats + 1)): |
| |
| phrase = " ".join(words[-phrase_len:]) |
| |
| phrase_escaped = re.escape(phrase) |
| phrase_pattern = re.compile( |
| r"(^|.*?\s)(" |
| + phrase_escaped |
| + r")(?:\s+" |
| + phrase_escaped |
| + r"){" |
| + str(min_repeats - 1) |
| + r",}\s*$", |
| re.IGNORECASE, |
| ) |
| match = phrase_pattern.match(text) |
| if match: |
| |
| text = (match.group(1) + match.group(2)).strip() |
| words = text.split() |
| break |
|
|
| return text |
|
|