Spaces:
Running on Zero
Running on Zero
| """ | |
Step 3: Transcribe audio with timestamps.

Backends, in the order `transcribe` tries them:
1. Pollinations Whisper API (verbose_json).
2. Local backend (device-dependent), used when Pollinations fails or
   returns no segments:
   - Apple MPS: mlx-whisper
   - CUDA (outside Hugging Face Spaces): openai-whisper
   - CPU, and Hugging Face Spaces by default: faster-whisper
| """ | |
| import os | |
| import requests | |
| import torch | |
| from dotenv import load_dotenv | |
| import spaces | |
| load_dotenv() | |
# Pollinations transcription endpoint that the audio file is POSTed to.
POLLINATIONS_URL = "https://gen.pollinations.ai/v1/audio/transcriptions"
# Model identifiers for each backend. Every default can be overridden through
# the environment variable named in the corresponding os.getenv call.
POLLEN_TRANSCRIBE_MODEL = os.getenv("POLLEN_TRANSCRIBE_MODEL", "whisper-large-v3")
MLX_MODEL = os.getenv("MLX_WHISPER_MODEL", "mlx-community/whisper-large-mlx")
FASTER_WHISPER_MODEL = os.getenv("FASTER_WHISPER_MODEL", "large-v3")
OPENAI_WHISPER_MODEL = os.getenv("OPENAI_WHISPER_MODEL", "large-v3")
# Name of the environment variable that forces a specific local backend.
LOCAL_WHISPER_BACKEND_ENV = "VIDEOVOICE_WHISPER_BACKEND"
# Backend identifiers accepted as a value of that override variable.
_VALID_LOCAL_BACKENDS = {
    "mlx-whisper",
    "openai-whisper-cuda",
    "faster-whisper-cpu",
}
# Loaded faster-whisper models, keyed by (device, compute_type).
_FASTER_WHISPER_MODELS = {}
# openai-whisper model, populated on first use and reused afterwards.
_OPENAI_WHISPER_MODEL = None
def _running_on_hf_space() -> bool:
    """Report whether any Hugging Face Spaces marker variable is set and non-empty."""
    marker_vars = ("SPACE_ID", "SPACE_HOST", "HF_SPACE_ID")
    return any(os.getenv(name) for name in marker_vars)
def _get_local_whisper_backend() -> str:
    """
    Pick the local transcription backend at call time.

    A non-empty override in the LOCAL_WHISPER_BACKEND_ENV variable wins and
    must be one of _VALID_LOCAL_BACKENDS. Otherwise the choice is: mlx-whisper
    when MPS is available, CPU faster-whisper on HF Spaces, openai-whisper
    when CUDA is available, and CPU faster-whisper as the last resort.

    HF Spaces default to CPU because ZeroGPU can report CUDA availability
    outside an active @spaces.GPU call, which makes import-time backend
    selection unreliable.

    Raises:
        ValueError: If the override is set to an unknown backend name.
    """
    forced = os.getenv(LOCAL_WHISPER_BACKEND_ENV, "").strip().lower()
    if forced:
        if forced in _VALID_LOCAL_BACKENDS:
            return forced
        raise ValueError(
            f"Invalid {LOCAL_WHISPER_BACKEND_ENV}={forced!r}. "
            f"Expected one of: {', '.join(sorted(_VALID_LOCAL_BACKENDS))}."
        )

    mps_available = (
        hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    )
    if mps_available:
        return "mlx-whisper"

    # PyTorch-based path so @spaces.GPU can intercept the CUDA allocation.
    # faster-whisper uses CTranslate2 which bypasses PyTorch and breaks ZeroGPU.
    if not _running_on_hf_space() and torch.cuda.is_available():
        return "openai-whisper-cuda"
    return "faster-whisper-cpu"
def _extract_words(raw_words: list[dict]) -> list[dict]:
    """
    Normalise word timestamps into {word, start, end}.

    Args:
        raw_words: Word dicts with "word", "start" and "end" keys. ``None``
            or an empty list yields an empty result.

    Returns:
        One dict per usable word, holding the stripped word text and float
        "start"/"end" values. Entries missing either timestamp are dropped.
    """
    output = []
    for raw in raw_words or []:
        start, end = raw.get("start"), raw.get("end")
        if start is None or end is None:
            continue
        text = raw.get("word")
        output.append(
            {
                # An explicit null word must become "", not the text "None".
                "word": "" if text is None else str(text).strip(),
                "start": float(start),
                "end": float(end),
            }
        )
    return output
def _normalise_segments(segments: list[dict]) -> list[dict]:
    """
    Convert raw segments to the canonical {start, end, text, words} form.

    Segments missing "start" or "end" are skipped. Timestamps are coerced to
    float, the text is stripped, and the words go through _extract_words.
    """
    canonical = []
    for segment in segments:
        begin, finish = segment.get("start"), segment.get("end")
        if begin is None or finish is None:
            continue
        word_list = _extract_words(segment.get("words", []))
        canonical.append(
            dict(
                start=float(begin),
                end=float(finish),
                text=str(segment.get("text", "")).strip(),
                words=word_list,
            )
        )
    return canonical
# Max duration (seconds) before a segment is considered oversized and needs splitting.
_MAX_SEGMENT_DURATION = 15.0
# Preferred pause gap (seconds) to use as a split point.
_PAUSE_THRESHOLD = 0.4
# Smallest gap (seconds) accepted as a split point once a chunk has already
# reached _MAX_SEGMENT_DURATION.
_FORCED_SPLIT_MIN_GAP = 0.15


def _split_oversized_segments(segments: list[dict]) -> list[dict]:
    """
    Split segments longer than _MAX_SEGMENT_DURATION using word timings.

    A split is made after a word when either
    - the current chunk has run for at least _MAX_SEGMENT_DURATION and the gap
      to the next word is at least _FORCED_SPLIT_MIN_GAP, or
    - the chunk has run for at least half of _MAX_SEGMENT_DURATION and the gap
      is at least _PAUSE_THRESHOLD.

    Args:
        segments: Canonical segments with "start", "end", "text" and "words".

    Returns:
        A new list. Segments that are short enough, have fewer than two
        non-empty words, or contain no acceptable split point are passed
        through unchanged. Segments that are split are replaced by one segment
        per chunk, with bounds and text rebuilt from the chunk's words.
    """
    output = []
    for seg in segments:
        duration = seg["end"] - seg["start"]
        real_words = [w for w in seg.get("words", []) if w["word"]]
        if duration <= _MAX_SEGMENT_DURATION or len(real_words) < 2:
            output.append(seg)
            continue

        chunks = []
        chunk_start_idx = 0
        chunk_start_time = real_words[0]["start"]
        for i, (current, following) in enumerate(zip(real_words, real_words[1:])):
            elapsed = current["end"] - chunk_start_time
            gap = following["start"] - current["end"]
            forced_split = (
                elapsed >= _MAX_SEGMENT_DURATION and gap >= _FORCED_SPLIT_MIN_GAP
            )
            pause_split = (
                elapsed >= _MAX_SEGMENT_DURATION * 0.5 and gap >= _PAUSE_THRESHOLD
            )
            if forced_split or pause_split:
                chunks.append(real_words[chunk_start_idx : i + 1])
                chunk_start_idx = i + 1
                chunk_start_time = following["start"]
        # The loop stops one word early, so a trailing chunk always remains.
        chunks.append(real_words[chunk_start_idx:])

        if len(chunks) == 1:
            # No usable pause was found. Keep the original segment rather than
            # rebuilding it from its words, which would silently change its
            # text and bounds without splitting anything.
            output.append(seg)
            continue

        output.extend(
            {
                "start": chunk_words[0]["start"],
                "end": chunk_words[-1]["end"],
                "text": " ".join(w["word"] for w in chunk_words).strip(),
                "words": chunk_words,
            }
            for chunk_words in chunks
        )
    return output
def _assign_words_to_segments(segments: list[dict], words: list[dict]) -> None:
    """
    Attach to each segment the top-level words lying fully inside its span.

    The words are first normalised with _extract_words. A word is kept for a
    segment only when its start is not before the segment start and its end is
    not after the segment end. ``segments`` is modified in place: each
    segment's "words" entry is overwritten.
    """
    candidates = _extract_words(words)

    def inside(word: dict, segment: dict) -> bool:
        return segment["start"] <= word["start"] and word["end"] <= segment["end"]

    for segment in segments:
        segment["words"] = [word for word in candidates if inside(word, segment)]
def _segments_from_pollinations(audio_path: str, language: str) -> list[dict]:
    """
    Call Pollinations Whisper API and return canonical segments.

    Args:
        audio_path: Path of the audio file to upload.
        language: Language code, or "auto"/empty to let Whisper detect it.

    Returns:
        Canonical segments. The list is empty when the response carries no
        segments.

    Raises:
        OSError: If the audio file cannot be opened.
        requests.RequestException: On transport failures, on the 120 s
            timeout, or on a non-success HTTP status.
    """
    api_key = (
        os.getenv("POLLEN_API_KEY_SECONDARY")
        or os.getenv("POLLEN_API_KEY")
        or os.getenv("POLLINATIONS_API_KEY", "")
    )
    # With no key configured, send no Authorization header at all rather than
    # an empty "Bearer " credential.
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}

    data = {
        "model": POLLEN_TRANSCRIBE_MODEL,
        "response_format": "verbose_json",
        "temperature": 0,
        # Request both granularities: this function reads "segments" from the
        # response as well as the word timings, and OpenAI-style endpoints
        # only populate the granularities that were asked for. requests sends
        # a list value as a repeated form field.
        "timestamp_granularities[]": ["word", "segment"],
    }
    # When the caller passes "auto" (or empty), omit the `language` field so
    # Whisper auto-detects. Forcing a wrong language code makes Whisper
    # silently switch to translate-mode (e.g. Hindi audio + language="en"
    # produces an English translation, not a Hindi transcript).
    if language and language.lower() not in ("auto", ""):
        data["language"] = language

    with open(audio_path, "rb") as audio_file:
        files = {"file": (os.path.basename(audio_path), audio_file, "audio/wav")}
        response = requests.post(
            POLLINATIONS_URL,
            headers=headers,
            files=files,
            data=data,
            timeout=120,
        )
    response.raise_for_status()
    result = response.json()

    # `or []` also covers a response that carries "segments": null.
    segments = _normalise_segments(result.get("segments") or [])
    top_level_words = result.get("words")
    if top_level_words and not any(seg["words"] for seg in segments):
        # Some responses list the words once at the top level instead of
        # nesting them inside each segment.
        _assign_words_to_segments(segments, top_level_words)
    return segments
def _segments_from_mlx(audio_path: str, language: str) -> list[dict]:
    """
    Run mlx-whisper locally.

    Args:
        audio_path: Path of the audio file to transcribe.
        language: Language code; "auto" is passed on as ``None``.

    Returns:
        Canonical segments built from the "segments" entry of the result.

    Raises:
        ImportError: If mlx-whisper is not installed. The original import
            failure is chained as the cause.
    """
    print("[s2] Using mlx-whisper backend...")
    try:
        import mlx_whisper
    except ImportError as exc:
        # Chain the cause so the underlying import failure stays visible.
        raise ImportError(
            "mlx-whisper is not installed. Run: uv add mlx-whisper"
        ) from exc
    result = mlx_whisper.transcribe(
        audio_path,
        path_or_hf_repo=MLX_MODEL,
        language=language if language != "auto" else None,
        word_timestamps=True,
    )
    return _normalise_segments(result.get("segments", []))
def _get_faster_whisper_model(device: str, compute_type: str):
    """Return the faster-whisper model for (device, compute_type), loading it on first request."""
    from faster_whisper import WhisperModel

    cache_key = (device, compute_type)
    if cache_key in _FASTER_WHISPER_MODELS:
        return _FASTER_WHISPER_MODELS[cache_key]

    loaded = WhisperModel(
        FASTER_WHISPER_MODEL,
        device=device,
        compute_type=compute_type,
    )
    _FASTER_WHISPER_MODELS[cache_key] = loaded
    return loaded
def _segments_from_faster_whisper_impl(
    audio_path: str,
    language: str,
    device: str,
    compute_type: str,
) -> list[dict]:
    """
    Transcribe with faster-whisper and return canonical segments.

    "auto" is passed to the model as ``None``. Words lacking a start or end
    timestamp are left out of a segment's word list.
    """
    whisper_model = _get_faster_whisper_model(device=device, compute_type=compute_type)
    forced_language = language if language != "auto" else None
    segment_iter, _info = whisper_model.transcribe(
        audio_path,
        language=forced_language,
        word_timestamps=True,
    )

    def timed_words(segment) -> list[dict]:
        return [
            {
                "word": str(item.word or "").strip(),
                "start": float(item.start),
                "end": float(item.end),
            }
            for item in segment.words or []
            if item.start is not None and item.end is not None
        ]

    return [
        {
            "start": float(segment.start),
            "end": float(segment.end),
            "text": str(segment.text or "").strip(),
            "words": timed_words(segment),
        }
        for segment in segment_iter
    ]
def _segments_from_faster_whisper_cpu(
    audio_path: str,
    language: str,
) -> list[dict]:
    """Run faster-whisper on the CPU with int8 compute (no GPU decorator is involved)."""
    return _segments_from_faster_whisper_impl(
        audio_path,
        language,
        device="cpu",
        compute_type="int8",
    )
def _get_openai_whisper_model():
    """
    Return the process-wide openai-whisper model, loading it on first use.

    The model is placed on CUDA when torch reports it available, else on CPU.

    Raises:
        ImportError: If the ``whisper`` package cannot be imported.
    """
    global _OPENAI_WHISPER_MODEL
    if _OPENAI_WHISPER_MODEL is not None:
        return _OPENAI_WHISPER_MODEL

    try:
        import whisper as openai_whisper
    except ImportError as exc:
        raise ImportError("openai-whisper is not installed") from exc

    if torch.cuda.is_available():
        target = "cuda"
    else:
        target = "cpu"
    print(f"[s2] Loading openai-whisper ({OPENAI_WHISPER_MODEL}) on {target}...")
    _OPENAI_WHISPER_MODEL = openai_whisper.load_model(
        OPENAI_WHISPER_MODEL, device=target
    )
    return _OPENAI_WHISPER_MODEL
def _segments_from_openai_whisper(
    audio_path: str,
    language: str,
) -> list[dict]:
    """
    Transcribe with openai-whisper and return canonical segments.

    "auto" is passed to the model as ``None``.
    """
    # NOTE(review): no GPU decorator is applied to this function here.
    # Confirm whether @spaces.GPU is meant to wrap it for ZeroGPU.
    whisper_model = _get_openai_whisper_model()
    forced_language = language if language != "auto" else None
    transcription = whisper_model.transcribe(
        audio_path,
        language=forced_language,
        word_timestamps=True,
        verbose=False,
    )
    raw_segments = transcription.get("segments", [])
    return _normalise_segments(raw_segments)
def _segments_from_local_backend(audio_path: str, language: str) -> list[dict]:
    """
    Run whichever local backend _get_local_whisper_backend() selects.

    If the openai-whisper path raises ImportError, CPU faster-whisper is used
    in its place.
    """
    selected = _get_local_whisper_backend()

    if selected == "mlx-whisper":
        return _segments_from_mlx(audio_path, language)

    if selected != "openai-whisper-cuda":
        print("[s2] Using faster-whisper backend (cpu)...")
        return _segments_from_faster_whisper_cpu(audio_path, language)

    print("[s2] Using openai-whisper backend (cuda)...")
    try:
        return _segments_from_openai_whisper(audio_path, language)
    except ImportError:
        print("[s2] openai-whisper unavailable; falling back to faster-whisper (cpu).")
        return _segments_from_faster_whisper_cpu(audio_path, language)
def transcribe(audio_path: str, language: str = "en") -> list[dict]:
    """
    Transcribe audio and return canonical segment schema.

    Priority:
    1. Pollinations API (fast, offloads computation)
    2. Local backend (GPU/MPS if available, otherwise CPU)

    Args:
        audio_path: Path of the audio file to transcribe.
        language: Language code, or "auto" for auto-detection.

    Returns:
        Canonical segments ({start, end, text, words}), with oversized
        segments split on pauses.

    Raises:
        RuntimeError: If Pollinations gave no segments and the local backend
            raised. The underlying exception is chained as the cause.
    """
    print(f"[s2] Transcribing {audio_path} (lang={language})...")
    segments = None
    pollinations_error = None
    local_error = None

    # 1. Try Pollinations API first
    try:
        print("[s2] Trying Pollinations API...")
        segments = _segments_from_pollinations(audio_path, language)
        if segments:
            print(f"[s2] Pollinations returned {len(segments)} segments ✓")
        else:
            # An empty result counts as a miss; say so instead of falling
            # back silently.
            print("[s2] Pollinations returned no segments — falling back to local backend.")
            segments = None
    except Exception as exc:  # any remote failure triggers the local fallback
        print(f"[s2] Pollinations error ({exc}) — falling back to local backend.")
        pollinations_error = exc
        segments = None

    # 2. Try Local Backend (GPU or CPU)
    if segments is None:
        try:
            backend = _get_local_whisper_backend()
            print(f"[s2] Trying local backend ({backend})...")
            segments = _segments_from_local_backend(audio_path, language)
            if segments:
                print(f"[s2] Local backend returned {len(segments)} segments ✓")
        except Exception as exc:  # reported through the RuntimeError below
            print(f"[s2] Local backend error ({exc}).")
            local_error = exc
            segments = None

    if segments is None:
        details = []
        if pollinations_error is not None:
            details.append(f"Pollinations: {pollinations_error}")
        if local_error is not None:
            details.append(f"Local backend: {local_error}")
        suffix = f" Details: {' | '.join(details)}" if details else ""
        # Chain the most recent failure so its traceback is not lost.
        cause = local_error if local_error is not None else pollinations_error
        raise RuntimeError(
            f"Transcription failed on all available backends.{suffix}"
        ) from cause

    before = len(segments)
    segments = _split_oversized_segments(segments)
    if len(segments) != before:
        print(f"[s2] Split {before} oversized segment(s) → {len(segments)} segments")
    return _normalise_segments(segments)