# videovoice/steps/s2_transcribe.py
# github-actions[bot] — deploy: switch to chatterbox requirements @ 21354c9 (80f0ab9)
"""
Step 3: Transcribe audio with timestamps.
Primary local backend (device-dependent):
- Apple MPS: mlx-whisper
- CUDA: faster-whisper
- CPU: faster-whisper
Outermost fallback:
- Pollinations Whisper API (verbose_json)
"""
import os
import requests
import torch
from dotenv import load_dotenv
import spaces
# Load variables from a local .env file into os.environ. This must run before
# the os.getenv() reads below, otherwise .env overrides would be ignored.
load_dotenv()
# Pollinations OpenAI-style transcription endpoint (remote backend).
POLLINATIONS_URL = "https://gen.pollinations.ai/v1/audio/transcriptions"
# Model identifiers per backend; each can be overridden via its env var.
POLLEN_TRANSCRIBE_MODEL = os.getenv("POLLEN_TRANSCRIBE_MODEL", "whisper-large-v3")
MLX_MODEL = os.getenv("MLX_WHISPER_MODEL", "mlx-community/whisper-large-mlx")
FASTER_WHISPER_MODEL = os.getenv("FASTER_WHISPER_MODEL", "large-v3")
OPENAI_WHISPER_MODEL = os.getenv("OPENAI_WHISPER_MODEL", "large-v3")
# Env var that forces a specific local backend; its (trimmed, lower-cased)
# value is validated against _VALID_LOCAL_BACKENDS.
LOCAL_WHISPER_BACKEND_ENV = "VIDEOVOICE_WHISPER_BACKEND"
_VALID_LOCAL_BACKENDS = {
    "mlx-whisper",
    "openai-whisper-cuda",
    "faster-whisper-cpu",
}
# Process-wide faster-whisper model cache keyed by (device, compute_type),
# filled lazily by _get_faster_whisper_model().
_FASTER_WHISPER_MODELS: dict[tuple[str, str], object] = {}
# Lazily loaded openai-whisper model instance; None until first use.
_OPENAI_WHISPER_MODEL = None
def _running_on_hf_space() -> bool:
    """Report whether a Hugging Face Spaces marker env var is set (non-empty).

    Checks SPACE_ID, SPACE_HOST and HF_SPACE_ID in that order and stops at
    the first non-empty one.
    """
    space_markers = ("SPACE_ID", "SPACE_HOST", "HF_SPACE_ID")
    return any(os.getenv(marker) for marker in space_markers)
def _get_local_whisper_backend() -> str:
    """
    Pick the local transcription backend at call time (not at import time).

    Resolution order:
        1. The VIDEOVOICE_WHISPER_BACKEND override (trimmed, lower-cased).
           An unknown value raises ValueError.
        2. Apple MPS available            -> "mlx-whisper"
        3. Running on a Hugging Face Space -> "faster-whisper-cpu"
        4. CUDA available                 -> "openai-whisper-cuda"
        5. Anything else                  -> "faster-whisper-cpu"

    The Space check deliberately precedes the CUDA check: ZeroGPU can report
    CUDA availability outside an active @spaces.GPU call, which makes
    import-time backend selection unreliable.
    """
    requested = os.getenv(LOCAL_WHISPER_BACKEND_ENV, "").strip().lower()
    if requested and requested not in _VALID_LOCAL_BACKENDS:
        allowed = ", ".join(sorted(_VALID_LOCAL_BACKENDS))
        raise ValueError(
            f"Invalid {LOCAL_WHISPER_BACKEND_ENV}={requested!r}. "
            f"Expected one of: {allowed}."
        )
    if requested:
        return requested

    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mlx-whisper"

    if _running_on_hf_space():
        return "faster-whisper-cpu"

    # The PyTorch-based path lets @spaces.GPU intercept the CUDA allocation;
    # faster-whisper uses CTranslate2, which bypasses PyTorch and breaks ZeroGPU.
    if torch.cuda.is_available():
        return "openai-whisper-cuda"
    return "faster-whisper-cpu"
def _extract_words(raw_words: list[dict]) -> list[dict]:
    """
    Normalise raw word-timestamp dicts into ``{"word", "start", "end"}``.

    Entries whose ``start`` or ``end`` is missing or None are skipped.
    ``word`` is stringified and whitespace-stripped (missing -> ""), and
    times are cast to float. A None/empty input yields an empty list.
    """

    def _convert(entry: dict):
        begin = entry.get("start")
        finish = entry.get("end")
        if begin is None or finish is None:
            return None
        return {
            "word": str(entry.get("word", "")).strip(),
            "start": float(begin),
            "end": float(finish),
        }

    # Lazy generator: entries are converted one at a time, in input order.
    converted = (_convert(entry) for entry in raw_words or [])
    return [item for item in converted if item is not None]
def _normalise_segments(segments: list[dict]) -> list[dict]:
    """
    Convert segment dicts to the canonical ``{start, end, text, words}`` schema.

    Segments lacking ``start`` or ``end`` (or with None values) are dropped.
    Times become floats, ``text`` is stringified and stripped, and ``words``
    is normalised through ``_extract_words``.
    """
    canonical = []
    for segment in segments:
        seg_start = segment.get("start")
        seg_end = segment.get("end")
        if seg_start is None or seg_end is None:
            continue
        # Words are normalised first, matching the original evaluation order.
        word_list = _extract_words(segment.get("words", []))
        canonical.append(
            {
                "start": float(seg_start),
                "end": float(seg_end),
                "text": str(segment.get("text", "")).strip(),
                "words": word_list,
            }
        )
    return canonical
# Max duration (seconds) before a segment is considered oversized and needs splitting.
_MAX_SEGMENT_DURATION = 15.0
# Preferred pause gap (seconds) to use as a split point once a chunk has
# reached half of the max duration.
_PAUSE_THRESHOLD = 0.4
# Smallest gap (seconds) accepted as a split point once a chunk has already
# reached the full max duration.
_MIN_HARD_SPLIT_GAP = 0.15


def _split_oversized_segments(
    segments: list[dict],
    *,
    max_duration: "float | None" = None,
    pause_threshold: "float | None" = None,
    min_gap: "float | None" = None,
) -> list[dict]:
    """
    Split segments longer than ``max_duration`` at pauses between words.

    A segment is a split candidate only when it is longer than
    ``max_duration`` and has at least two non-empty words. A word boundary
    becomes a split point when either:

    - the running chunk has reached ``max_duration`` and the following gap is
      at least ``min_gap`` seconds, or
    - the running chunk has reached half of ``max_duration`` and the
      following gap is at least ``pause_threshold`` seconds.

    Segments that are short enough, or for which no split point is found, are
    returned unchanged, so Whisper's original text and boundaries are kept.
    Each chunk of a split segment takes its start/end from its first/last
    word. Its text is rebuilt by joining the word tokens with spaces.

    Args:
        segments: canonical segments (``start``, ``end``, ``text``, ``words``).
        max_duration: length limit in seconds (default ``_MAX_SEGMENT_DURATION``).
        pause_threshold: preferred pause gap (default ``_PAUSE_THRESHOLD``).
        min_gap: smallest gap accepted at the hard limit
            (default ``_MIN_HARD_SPLIT_GAP``).

    Returns:
        A new list of segments. Input segment dicts are not mutated.
    """
    if max_duration is None:
        max_duration = _MAX_SEGMENT_DURATION
    if pause_threshold is None:
        pause_threshold = _PAUSE_THRESHOLD
    if min_gap is None:
        min_gap = _MIN_HARD_SPLIT_GAP
    soft_limit = max_duration * 0.5

    output = []
    for seg in segments:
        duration = seg["end"] - seg["start"]
        real_words = [w for w in seg.get("words") or [] if w["word"]]
        if duration <= max_duration or len(real_words) < 2:
            output.append(seg)
            continue

        chunks = []
        chunk_start_idx = 0
        chunk_start_time = real_words[0]["start"]
        for i, (word, next_word) in enumerate(zip(real_words, real_words[1:])):
            elapsed = word["end"] - chunk_start_time
            gap = next_word["start"] - word["end"]
            if (elapsed >= max_duration and gap >= min_gap) or (
                elapsed >= soft_limit and gap >= pause_threshold
            ):
                chunks.append(real_words[chunk_start_idx : i + 1])
                chunk_start_idx = i + 1
                chunk_start_time = next_word["start"]
        # chunk_start_idx never exceeds len(real_words) - 1, so a non-empty
        # tail chunk always remains.
        chunks.append(real_words[chunk_start_idx:])

        if len(chunks) == 1:
            # No usable pause: keep Whisper's own text and timing rather than
            # rebuilding them from stripped word tokens (which would, e.g.,
            # insert spaces between CJK characters).
            output.append(seg)
            continue

        output.extend(
            {
                "start": chunk[0]["start"],
                "end": chunk[-1]["end"],
                "text": " ".join(w["word"] for w in chunk).strip(),
                "words": chunk,
            }
            for chunk in chunks
        )
    return output
def _assign_words_to_segments(segments: list[dict], words: list[dict]) -> None:
    """
    Distribute a top-level word list into ``segments`` (mutated in place).

    Each word is assigned to the first segment whose ``[start, end]`` range
    contains the word's midpoint. A word that slightly overhangs a segment
    boundary is therefore kept instead of dropped, and no word is assigned
    to two segments. Words whose midpoint lies outside every segment are
    discarded.

    Args:
        segments: canonical segments. Every segment's ``words`` key is
            overwritten (an empty list if no word lands in it).
        words: raw word dicts, normalised via ``_extract_words``.
    """
    normalised = _extract_words(words)
    for seg in segments:
        seg["words"] = []
    for word in normalised:
        midpoint = (word["start"] + word["end"]) / 2.0
        owner = next(
            (seg for seg in segments if seg["start"] <= midpoint <= seg["end"]),
            None,
        )
        if owner is not None:
            owner["words"].append(word)
def _segments_from_pollinations(audio_path: str, language: str) -> list[dict]:
    """
    Transcribe ``audio_path`` via the Pollinations Whisper API.

    The API key is the first non-empty value of POLLEN_API_KEY_SECONDARY,
    POLLEN_API_KEY or POLLINATIONS_API_KEY. The file is uploaded as
    ``audio/wav`` and a ``verbose_json`` response is requested.

    Both ``word`` and ``segment`` timestamp granularities are requested.
    OpenAI-compatible endpoints populate only the granularities asked for,
    so asking for ``word`` alone can return no ``segments`` at all. As
    further safeguards:

    - segments that carry no per-segment words get the top-level ``words``
      distributed into them;
    - a words-only response is wrapped into one segment spanning all words.
      It may later be split at pauses by the oversized-segment splitter.

    Args:
        audio_path: path to the audio file to upload.
        language: Whisper language code. Empty or "auto" (any case) omits
            the field so the server auto-detects.

    Returns:
        Canonical segments (possibly empty).

    Raises:
        OSError: if ``audio_path`` cannot be opened.
        requests.RequestException: on network failure, timeout (120 s), or a
            non-2xx status (via ``raise_for_status``).
        ValueError: if the response body is not valid JSON.
    """
    api_key = (
        os.getenv("POLLEN_API_KEY_SECONDARY")
        or os.getenv("POLLEN_API_KEY")
        or os.getenv("POLLINATIONS_API_KEY", "")
    )
    headers = {"Authorization": f"Bearer {api_key}"}
    data = {
        "model": POLLEN_TRANSCRIBE_MODEL,
        "response_format": "verbose_json",
        "temperature": 0,
        # A list value makes requests emit the multipart field once per item.
        "timestamp_granularities[]": ["word", "segment"],
    }
    # When the caller passes "auto" (or empty), omit the `language` field so
    # Whisper auto-detects. Forcing a wrong language code makes Whisper
    # silently switch to translate-mode (e.g. Hindi audio + language="en"
    # produces an English translation, not a Hindi transcript).
    if language and language.lower() != "auto":
        data["language"] = language

    with open(audio_path, "rb") as audio_file:
        files = {"file": (os.path.basename(audio_path), audio_file, "audio/wav")}
        response = requests.post(
            POLLINATIONS_URL,
            headers=headers,
            files=files,
            data=data,
            timeout=120,
        )
    response.raise_for_status()
    result = response.json()

    # `or []` also covers an explicit JSON null.
    segments = _normalise_segments(result.get("segments") or [])
    top_level_words = result.get("words") or []
    if segments:
        if top_level_words and not any(seg["words"] for seg in segments):
            _assign_words_to_segments(segments, top_level_words)
    else:
        words = _extract_words(top_level_words)
        if words:
            text = str(result.get("text") or "").strip()
            segments = [
                {
                    "start": min(w["start"] for w in words),
                    "end": max(w["end"] for w in words),
                    "text": text or " ".join(w["word"] for w in words if w["word"]),
                    "words": words,
                }
            ]
    return _normalise_segments(segments)
def _segments_from_mlx(audio_path: str, language: str) -> list[dict]:
    """
    Transcribe locally with mlx-whisper (Apple silicon / MPS path).

    Args:
        audio_path: path to the audio file.
        language: Whisper language code. Empty or "auto" (any case) means
            auto-detect, matching the Pollinations path.

    Returns:
        Canonical segments with word timestamps.

    Raises:
        ImportError: if the ``mlx_whisper`` package is not installed.
    """
    print("[s2] Using mlx-whisper backend...")
    try:
        import mlx_whisper
    except ImportError as exc:
        raise ImportError("mlx-whisper is not installed. Run: uv add mlx-whisper") from exc
    whisper_language = None if not language or language.lower() == "auto" else language
    result = mlx_whisper.transcribe(
        audio_path,
        path_or_hf_repo=MLX_MODEL,
        language=whisper_language,
        word_timestamps=True,
    )
    return _normalise_segments(result.get("segments", []))
def _get_faster_whisper_model(device: str, compute_type: str):
    """
    Return a faster-whisper ``WhisperModel`` for (device, compute_type).

    Models are built on first request and memoised in the module-level
    ``_FASTER_WHISPER_MODELS`` dict, so each (device, compute_type) pair is
    loaded at most once per process.
    """
    from faster_whisper import WhisperModel

    cache_key = (device, compute_type)
    cached = _FASTER_WHISPER_MODELS.get(cache_key)
    if cached is None:
        cached = WhisperModel(
            FASTER_WHISPER_MODEL,
            device=device,
            compute_type=compute_type,
        )
        _FASTER_WHISPER_MODELS[cache_key] = cached
    return cached
def _segments_from_faster_whisper_impl(
    audio_path: str,
    language: str,
    device: str,
    compute_type: str,
) -> list[dict]:
    """
    Transcribe with faster-whisper on the given device and compute type.

    Args:
        audio_path: path to the audio file.
        language: Whisper language code. Empty or "auto" (any case) means
            auto-detect, matching the Pollinations path.
        device: faster-whisper device string (e.g. "cpu").
        compute_type: faster-whisper compute type (e.g. "int8").

    Returns:
        Canonical segments. Words with a missing start or end are skipped.
    """
    model = _get_faster_whisper_model(device=device, compute_type=compute_type)
    whisper_language = None if not language or language.lower() == "auto" else language
    segments, _info = model.transcribe(
        audio_path,
        language=whisper_language,
        word_timestamps=True,
    )
    output = []
    for seg in segments:
        words = [
            {
                "word": str(word.word or "").strip(),
                "start": float(word.start),
                "end": float(word.end),
            }
            for word in seg.words or []
            if word.start is not None and word.end is not None
        ]
        output.append(
            {
                "start": float(seg.start),
                "end": float(seg.end),
                "text": str(seg.text or "").strip(),
                "words": words,
            }
        )
    return output
def _segments_from_faster_whisper_cpu(
    audio_path: str,
    language: str,
) -> list[dict]:
    """
    Run faster-whisper on CPU with int8 weights.

    Deliberately not wrapped in @spaces.GPU, so it runs outside the ZeroGPU
    budget.
    """
    return _segments_from_faster_whisper_impl(
        audio_path=audio_path,
        language=language,
        device="cpu",
        compute_type="int8",
    )
def _get_openai_whisper_model():
    """
    Return the process-wide openai-whisper model, loading it on first use.

    The device is chosen once, at first load: "cuda" if torch reports CUDA
    available at that moment, otherwise "cpu". Later calls return the cached
    instance unchanged.

    Raises:
        ImportError: if the ``whisper`` (openai-whisper) package is missing.
    """
    global _OPENAI_WHISPER_MODEL
    if _OPENAI_WHISPER_MODEL is not None:
        return _OPENAI_WHISPER_MODEL

    try:
        import whisper as openai_whisper
    except ImportError as exc:
        raise ImportError("openai-whisper is not installed") from exc

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[s2] Loading openai-whisper ({OPENAI_WHISPER_MODEL}) on {target_device}...")
    _OPENAI_WHISPER_MODEL = openai_whisper.load_model(OPENAI_WHISPER_MODEL, device=target_device)
    return _OPENAI_WHISPER_MODEL
@spaces.GPU(duration=60)
def _segments_from_openai_whisper(
    audio_path: str,
    language: str,
) -> list[dict]:
    """
    Transcribe with openai-whisper inside a ``@spaces.GPU`` call.

    openai-whisper is PyTorch-native, so ZeroGPU can intercept its CUDA
    allocation. ``duration=60`` is passed to ``spaces.GPU``; presumably this
    is the per-call GPU time budget in seconds (confirm against ZeroGPU docs
    if long audio times out).

    Args:
        audio_path: path to the audio file.
        language: Whisper language code. Empty or "auto" (any case) means
            auto-detect, matching the Pollinations path.

    Returns:
        Canonical segments with word timestamps.

    Raises:
        ImportError: propagated from the model loader when openai-whisper is
            not installed.
    """
    model = _get_openai_whisper_model()
    whisper_language = None if not language or language.lower() == "auto" else language
    result = model.transcribe(
        audio_path,
        language=whisper_language,
        word_timestamps=True,
        verbose=False,
    )
    return _normalise_segments(result.get("segments", []))
def _segments_from_local_backend(audio_path: str, language: str) -> list[dict]:
    """
    Run whichever local Whisper backend ``_get_local_whisper_backend()`` picks.

    For "openai-whisper-cuda", an ImportError (openai-whisper missing)
    degrades to CPU faster-whisper. Any other error propagates to the caller.
    """
    selected = _get_local_whisper_backend()

    if selected == "mlx-whisper":
        # _segments_from_mlx prints its own banner.
        return _segments_from_mlx(audio_path, language)

    if selected != "openai-whisper-cuda":
        print("[s2] Using faster-whisper backend (cpu)...")
        return _segments_from_faster_whisper_cpu(audio_path, language)

    print("[s2] Using openai-whisper backend (cuda)...")
    try:
        return _segments_from_openai_whisper(audio_path, language)
    except ImportError:
        print("[s2] openai-whisper unavailable; falling back to faster-whisper (cpu).")
        return _segments_from_faster_whisper_cpu(audio_path, language)
def transcribe(audio_path: str, language: str = "en") -> list[dict]:
    """
    Transcribe ``audio_path`` and return canonical segments.

    Priority:
        1. Pollinations API (offloads computation). Any exception, or an
           empty result, falls through to step 2.
        2. Local backend chosen by ``_get_local_whisper_backend()``
           (MPS / HF Space / CUDA / CPU).

    The resulting segments are passed through ``_split_oversized_segments``
    and re-normalised.

    Args:
        audio_path: path of the audio file to transcribe.
        language: Whisper language code, or "auto" for auto-detection.

    Returns:
        List of ``{"start", "end", "text", "words"}`` dicts. May be empty if
        the local backend ran successfully but produced no segments.

    Raises:
        RuntimeError: if Pollinations fails or returns nothing AND the local
            backend raises. The message includes the collected errors.
    """
    print(f"[s2] Transcribing {audio_path} (lang={language})...")
    segments = None
    pollinations_error = None
    local_error = None

    # 1. Remote API first.
    try:
        print("[s2] Trying Pollinations API...")
        segments = _segments_from_pollinations(audio_path, language)
        if segments:
            print(f"[s2] Pollinations returned {len(segments)} segments ✓")
        else:
            print("[s2] Pollinations returned no segments — falling back to local backend.")
            segments = None
    except Exception as exc:  # deliberate boundary: any API failure -> local fallback
        print(f"[s2] Pollinations error ({exc}) — falling back to local backend.")
        pollinations_error = exc
        segments = None

    # 2. Local backend (GPU/MPS or CPU).
    if segments is None:
        try:
            backend = _get_local_whisper_backend()
            print(f"[s2] Trying local backend ({backend})...")
            segments = _segments_from_local_backend(audio_path, language)
            if segments:
                print(f"[s2] Local backend returned {len(segments)} segments ✓")
        except Exception as exc:  # collected and reported in the RuntimeError below
            print(f"[s2] Local backend error ({exc}).")
            local_error = exc
            segments = None

    if segments is None:
        details = []
        if pollinations_error is not None:
            details.append(f"Pollinations: {pollinations_error}")
        if local_error is not None:
            details.append(f"Local backend: {local_error}")
        suffix = f" Details: {' | '.join(details)}" if details else ""
        raise RuntimeError(f"Transcription failed on all available backends.{suffix}")

    before = len(segments)
    segments = _split_oversized_segments(segments)
    if len(segments) != before:
        # `before` is the total segment count, not the number of oversized ones.
        print(f"[s2] Split oversized segments: {before} → {len(segments)} segments")
    return _normalise_segments(segments)