"""
Step 3: Transcribe audio with timestamps.

Backend order (as implemented by ``transcribe``):
  1. Pollinations Whisper API (verbose_json) is tried first.
  2. Local backend (device-dependent) is the fallback:
     - Apple MPS: mlx-whisper
     - CUDA: openai-whisper (PyTorch-based)
     - CPU, and Hugging Face Spaces by default: faster-whisper

The local backend can be forced with the VIDEOVOICE_WHISPER_BACKEND
environment variable.
"""

import os

import requests
import torch
from dotenv import load_dotenv
import spaces

load_dotenv()

POLLINATIONS_URL = "https://gen.pollinations.ai/v1/audio/transcriptions"
POLLEN_TRANSCRIBE_MODEL = os.getenv("POLLEN_TRANSCRIBE_MODEL", "whisper-large-v3")

MLX_MODEL = os.getenv("MLX_WHISPER_MODEL", "mlx-community/whisper-large-mlx")
FASTER_WHISPER_MODEL = os.getenv("FASTER_WHISPER_MODEL", "large-v3")
OPENAI_WHISPER_MODEL = os.getenv("OPENAI_WHISPER_MODEL", "large-v3")

# Name of the environment variable that overrides local backend detection.
LOCAL_WHISPER_BACKEND_ENV = "VIDEOVOICE_WHISPER_BACKEND"
_VALID_LOCAL_BACKENDS = {
    "mlx-whisper",
    "openai-whisper-cuda",
    "faster-whisper-cpu",
}

# Process-wide model caches, filled lazily by the loader functions.
_FASTER_WHISPER_MODELS = {}
_OPENAI_WHISPER_MODEL = None


def _running_on_hf_space() -> bool:
    """Return True when any of the Hugging Face Space env markers is set."""
    return bool(
        os.getenv("SPACE_ID")
        or os.getenv("SPACE_HOST")
        or os.getenv("HF_SPACE_ID")
    )


def _get_local_whisper_backend() -> str:
    """
    Resolve the local transcription backend lazily.

    Resolution order:
      1. Explicit override via VIDEOVOICE_WHISPER_BACKEND (validated).
      2. "mlx-whisper" when the MPS backend is available.
      3. "faster-whisper-cpu" on HF Spaces.
      4. "openai-whisper-cuda" when CUDA is available.
      5. "faster-whisper-cpu" otherwise.

    On HF Spaces, default to CPU faster-whisper unless explicitly overridden.
    ZeroGPU can report CUDA availability outside an active @spaces.GPU call,
    which makes import-time backend selection unreliable.

    Raises:
        ValueError: if the override is set to an unknown backend name.
    """
    override = os.getenv(LOCAL_WHISPER_BACKEND_ENV, "").strip().lower()
    if override:
        if override not in _VALID_LOCAL_BACKENDS:
            raise ValueError(
                f"Invalid {LOCAL_WHISPER_BACKEND_ENV}={override!r}. "
                f"Expected one of: {', '.join(sorted(_VALID_LOCAL_BACKENDS))}."
            )
        return override

    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mlx-whisper"
    if _running_on_hf_space():
        return "faster-whisper-cpu"
    if torch.cuda.is_available():
        # PyTorch-based path so @spaces.GPU can intercept the CUDA allocation.
        # faster-whisper uses CTranslate2 which bypasses PyTorch and breaks ZeroGPU.
        return "openai-whisper-cuda"
    return "faster-whisper-cpu"


def _extract_words(raw_words: list[dict]) -> list[dict]:
    """
    Normalise word timestamps into {word, start, end}.

    Args:
        raw_words: word dicts as returned by a Whisper backend; ``None`` or an
            empty list yields an empty result.

    Returns:
        A new list of ``{"word": str, "start": float, "end": float}`` dicts.
        Entries without a start or end timestamp are skipped.
    """
    output = []
    for raw in raw_words or []:
        start = raw.get("start")
        end = raw.get("end")
        if start is None or end is None:
            continue
        output.append(
            {
                # `or ""` so an explicit None does not become the text "None".
                "word": str(raw.get("word") or "").strip(),
                "start": float(start),
                "end": float(end),
            }
        )
    return output


def _normalise_segments(segments: list[dict]) -> list[dict]:
    """
    Return canonical segment schema with word-level timestamps.

    Args:
        segments: segment dicts as returned by a Whisper backend; ``None`` or
            an empty list yields an empty result.

    Returns:
        A new list of ``{"start": float, "end": float, "text": str,
        "words": list[dict]}`` dicts. Segments without a start or end
        timestamp are skipped.
    """
    output = []
    for seg in segments or []:
        start = seg.get("start")
        end = seg.get("end")
        if start is None or end is None:
            continue
        output.append(
            {
                "start": float(start),
                "end": float(end),
                # `or ""` so an explicit None does not become the text "None".
                "text": str(seg.get("text") or "").strip(),
                "words": _extract_words(seg.get("words", [])),
            }
        )
    return output


# Max duration (seconds) before a segment is considered oversized and needs splitting.
_MAX_SEGMENT_DURATION = 15.0

# Preferred pause gap (seconds) to use as a split point.
_PAUSE_THRESHOLD = 0.4
# Smallest inter-word gap (seconds) accepted as a split point once a chunk has
# already reached _MAX_SEGMENT_DURATION.
_FORCED_SPLIT_MIN_GAP = 0.15


def _language_or_none(language: str) -> "str | None":
    """
    Map the caller's language argument to what Whisper backends expect.

    Returns ``None`` (meaning "auto-detect") for ``None``, an empty/blank
    string, or "auto" in any letter case; otherwise the stripped code.
    """
    code = (language or "").strip()
    if not code or code.lower() == "auto":
        return None
    return code


def _split_oversized_segments(segments: list[dict]) -> list[dict]:
    """
    Split segments longer than _MAX_SEGMENT_DURATION using word timings.

    A segment is passed through unchanged when it is short enough or has fewer
    than two non-empty words. Otherwise it is cut between two words when
    either:
      - the current chunk has reached _MAX_SEGMENT_DURATION and the gap to the
        next word is at least _FORCED_SPLIT_MIN_GAP, or
      - the current chunk has reached half of _MAX_SEGMENT_DURATION and the
        gap is at least _PAUSE_THRESHOLD.

    Returns:
        A new list; split chunks take their start/end from their first/last
        word and their text from the joined words.
    """
    output = []
    for seg in segments:
        duration = seg["end"] - seg["start"]
        real_words = [w for w in seg.get("words") or [] if w.get("word")]
        if duration <= _MAX_SEGMENT_DURATION or len(real_words) < 2:
            output.append(seg)
            continue

        chunks = []
        chunk_start_idx = 0
        chunk_start_time = real_words[0]["start"]
        for i, (word, next_word) in enumerate(zip(real_words, real_words[1:])):
            elapsed = word["end"] - chunk_start_time
            gap = next_word["start"] - word["end"]
            should_split = (
                (elapsed >= _MAX_SEGMENT_DURATION and gap >= _FORCED_SPLIT_MIN_GAP)
                or (elapsed >= _MAX_SEGMENT_DURATION * 0.5 and gap >= _PAUSE_THRESHOLD)
            )
            if should_split:
                chunks.append(real_words[chunk_start_idx : i + 1])
                chunk_start_idx = i + 1
                chunk_start_time = next_word["start"]
        # The last word is never a split point, so a tail chunk always remains.
        chunks.append(real_words[chunk_start_idx:])

        for chunk_words in chunks:
            output.append(
                {
                    "start": chunk_words[0]["start"],
                    "end": chunk_words[-1]["end"],
                    "text": " ".join(w["word"] for w in chunk_words).strip(),
                    "words": chunk_words,
                }
            )
    return output


def _assign_words_to_segments(segments: list[dict], words: list[dict]) -> None:
    """
    Distribute a top-level word list into segments, in place.

    Each word goes to the first segment whose [start, end] range contains the
    word's midpoint. Using the midpoint (rather than requiring the whole word
    to lie inside the segment) keeps words whose timestamps straddle a segment
    edge, and the first-match rule prevents a word from appearing twice.
    Words whose midpoint falls in no segment are left unassigned.
    """
    normalised = _extract_words(words)
    for seg in segments:
        seg["words"] = []
    for word in normalised:
        midpoint = (word["start"] + word["end"]) / 2.0
        for seg in segments:
            if seg["start"] <= midpoint <= seg["end"]:
                seg["words"].append(word)
                break


def _segments_from_pollinations(audio_path: str, language: str) -> list[dict]:
    """
    Call Pollinations Whisper API and return canonical segments.

    The API key is read from POLLEN_API_KEY_SECONDARY, POLLEN_API_KEY or
    POLLINATIONS_API_KEY (first non-empty wins).

    Raises:
        OSError: if the audio file cannot be opened.
        requests.RequestException: on network failure or a non-2xx response.
    """
    api_key = (
        os.getenv("POLLEN_API_KEY_SECONDARY")
        or os.getenv("POLLEN_API_KEY")
        or os.getenv("POLLINATIONS_API_KEY", "")
    )
    # Do not send an empty bearer token when no key is configured.
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}

    # NOTE(review): only "word" granularity is requested. OpenAI-compatible
    # servers may omit `segments` unless "segment" is requested too — confirm
    # Pollinations returns segments for this request.
    data = {
        "model": POLLEN_TRANSCRIBE_MODEL,
        "response_format": "verbose_json",
        "temperature": 0,
        "timestamp_granularities[]": "word",
    }
    # When the caller passes "auto" (or empty), omit the `language` field so
    # Whisper auto-detects. Forcing a wrong language code makes Whisper
    # silently switch to translate-mode (e.g. Hindi audio + language="en"
    # produces an English translation, not a Hindi transcript).
    language_code = _language_or_none(language)
    if language_code is not None:
        data["language"] = language_code

    # The upload must happen while the file handle is still open.
    with open(audio_path, "rb") as audio_file:
        files = {"file": (os.path.basename(audio_path), audio_file, "audio/wav")}
        response = requests.post(
            POLLINATIONS_URL,
            headers=headers,
            files=files,
            data=data,
            timeout=120,
        )
    response.raise_for_status()
    result = response.json()

    # `or []` also covers an explicit JSON null.
    segments = _normalise_segments(result.get("segments") or [])
    if not any(seg.get("words") for seg in segments) and result.get("words"):
        _assign_words_to_segments(segments, result["words"])
    return segments


def _segments_from_mlx(audio_path: str, language: str) -> list[dict]:
    """
    Run mlx-whisper locally.

    Raises:
        ImportError: if the mlx_whisper package is not installed.
    """
    print("[s2] Using mlx-whisper backend...")
    try:
        import mlx_whisper
    except ImportError as exc:
        raise ImportError(
            "mlx-whisper is not installed. Run: uv add mlx-whisper"
        ) from exc

    result = mlx_whisper.transcribe(
        audio_path,
        path_or_hf_repo=MLX_MODEL,
        language=_language_or_none(language),
        word_timestamps=True,
    )
    return _normalise_segments(result.get("segments") or [])


def _get_faster_whisper_model(device: str, compute_type: str):
    """Load a faster-whisper model, cached per (device, compute_type)."""
    from faster_whisper import WhisperModel

    key = (device, compute_type)
    if key not in _FASTER_WHISPER_MODELS:
        _FASTER_WHISPER_MODELS[key] = WhisperModel(
            FASTER_WHISPER_MODEL,
            device=device,
            compute_type=compute_type,
        )
    return _FASTER_WHISPER_MODELS[key]


def _segments_from_faster_whisper_impl(
    audio_path: str,
    language: str,
    device: str,
    compute_type: str,
) -> list[dict]:
    """Transcribe with faster-whisper and convert its objects to canonical dicts."""
    model = _get_faster_whisper_model(device=device, compute_type=compute_type)
    segments, _ = model.transcribe(
        audio_path,
        language=_language_or_none(language),
        word_timestamps=True,
    )

    output = []
    for seg in segments:
        # Words without both timestamps are skipped.
        words = [
            {
                "word": str(word.word or "").strip(),
                "start": float(word.start),
                "end": float(word.end),
            }
            for word in seg.words or []
            if word.start is not None and word.end is not None
        ]
        output.append(
            {
                "start": float(seg.start),
                "end": float(seg.end),
                "text": str(seg.text or "").strip(),
                "words": words,
            }
        )
    return output


def _segments_from_faster_whisper_cpu(
    audio_path: str,
    language: str,
) -> list[dict]:
    """CPU-only faster-whisper (no GPU decorator — runs outside ZeroGPU budget)."""
    return _segments_from_faster_whisper_impl(audio_path, language, "cpu", "int8")


def _get_openai_whisper_model():
    """
    Load openai-whisper once per process. CUDA if available.

    Raises:
        ImportError: if the whisper package is not installed.
    """
    global _OPENAI_WHISPER_MODEL
    if _OPENAI_WHISPER_MODEL is None:
        try:
            import whisper as openai_whisper
        except ImportError as exc:
            raise ImportError("openai-whisper is not installed") from exc

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[s2] Loading openai-whisper ({OPENAI_WHISPER_MODEL}) on {device}...")
        _OPENAI_WHISPER_MODEL = openai_whisper.load_model(
            OPENAI_WHISPER_MODEL, device=device
        )
    return _OPENAI_WHISPER_MODEL


@spaces.GPU(duration=60)
def _segments_from_openai_whisper(
    audio_path: str,
    language: str,
) -> list[dict]:
    """GPU-decorated openai-whisper execution (PyTorch-native, ZeroGPU-compatible)."""
    model = _get_openai_whisper_model()
    result = model.transcribe(
        audio_path,
        language=_language_or_none(language),
        word_timestamps=True,
        verbose=False,
    )
    return _normalise_segments(result.get("segments") or [])


def _segments_from_local_backend(audio_path: str, language: str) -> list[dict]:
    """
    Dispatch local whisper backend from runtime device detection.

    If the openai-whisper backend is selected but raises ImportError, falls
    back to CPU faster-whisper.
    """
    backend = _get_local_whisper_backend()
    if backend == "mlx-whisper":
        return _segments_from_mlx(audio_path, language)
    if backend == "openai-whisper-cuda":
        print("[s2] Using openai-whisper backend (cuda)...")
        try:
            return _segments_from_openai_whisper(audio_path, language)
        except ImportError:
            print("[s2] openai-whisper unavailable; falling back to faster-whisper (cpu).")
            return _segments_from_faster_whisper_cpu(audio_path, language)
    print("[s2] Using faster-whisper backend (cpu)...")
    return _segments_from_faster_whisper_cpu(audio_path, language)


def transcribe(audio_path: str, language: str = "en") -> list[dict]:
    """
    Transcribe audio and return canonical segment schema.

    Priority:
    1. Pollinations API (fast, offloads computation)
    2. Local backend (GPU/MPS if available, otherwise CPU)

    Args:
        audio_path: path of the audio file to transcribe.
        language: language code, or "auto"/empty to let Whisper detect it.

    Returns:
        List of ``{"start", "end", "text", "words"}`` segment dicts, with
        oversized segments split on word timings.

    Raises:
        RuntimeError: if the local backend fails after Pollinations failed or
            returned no segments; chained to the underlying exception.
    """
    print(f"[s2] Transcribing {audio_path} (lang={language})...")

    segments = None
    pollinations_error = None
    local_error = None

    # 1. Try Pollinations API first
    try:
        print("[s2] Trying Pollinations API...")
        segments = _segments_from_pollinations(audio_path, language)
        if segments:
            print(f"[s2] Pollinations returned {len(segments)} segments ✓")
        else:
            # An empty result is treated like a failure so the local backend runs.
            print("[s2] Pollinations returned no segments — falling back to local backend.")
            segments = None
    except Exception as exc:
        print(f"[s2] Pollinations error ({exc}) — falling back to local backend.")
        pollinations_error = exc
        segments = None

    # 2. Try Local Backend (GPU or CPU)
    if segments is None:
        try:
            backend = _get_local_whisper_backend()
            print(f"[s2] Trying local backend ({backend})...")
            segments = _segments_from_local_backend(audio_path, language)
            if segments:
                print(f"[s2] Local backend returned {len(segments)} segments ✓")
        except Exception as exc:
            print(f"[s2] Local backend error ({exc}).")
            local_error = exc
            segments = None

    if segments is None:
        details = []
        if pollinations_error is not None:
            details.append(f"Pollinations: {pollinations_error}")
        if local_error is not None:
            details.append(f"Local backend: {local_error}")
        suffix = f" Details: {' | '.join(details)}" if details else ""
        # Chain to the most recent underlying failure to keep its traceback.
        cause = local_error if local_error is not None else pollinations_error
        raise RuntimeError(
            f"Transcription failed on all available backends.{suffix}"
        ) from cause

    before = len(segments)
    segments = _split_oversized_segments(segments)
    if len(segments) != before:
        print(f"[s2] Split {before} oversized segment(s) → {len(segments)} segments")

    return _normalise_segments(segments)