Rafii committed on
Commit
abc7c46
·
1 Parent(s): fc71180

deploy: switch to chatterbox requirements @ 68ada45

Browse files
Files changed (1) hide show
  1. steps/s2_transcribe.py +63 -14
steps/s2_transcribe.py CHANGED
@@ -24,20 +24,56 @@ POLLEN_TRANSCRIBE_MODEL = os.getenv("POLLEN_TRANSCRIBE_MODEL", "whisper-large-v3
24
  MLX_MODEL = os.getenv("MLX_WHISPER_MODEL", "mlx-community/whisper-large-mlx")
25
  FASTER_WHISPER_MODEL = os.getenv("FASTER_WHISPER_MODEL", "large-v3")
26
  OPENAI_WHISPER_MODEL = os.getenv("OPENAI_WHISPER_MODEL", "large-v3")
27
-
28
- if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
29
- LOCAL_WHISPER_BACKEND = "mlx-whisper"
30
- elif torch.cuda.is_available():
31
- # PyTorch-based path so @spaces.GPU can intercept the CUDA allocation.
32
- # faster-whisper uses CTranslate2 which bypasses PyTorch and breaks ZeroGPU.
33
- LOCAL_WHISPER_BACKEND = "openai-whisper-cuda"
34
- else:
35
- LOCAL_WHISPER_BACKEND = "faster-whisper-cpu"
36
 
37
  _FASTER_WHISPER_MODELS = {}
38
  _OPENAI_WHISPER_MODEL = None
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def _extract_words(raw_words: list[dict]) -> list[dict]:
42
  """Normalise word timestamps into {word, start, end}."""
43
  output = []
@@ -279,11 +315,13 @@ def _segments_from_openai_whisper(
279
 
280
 
281
  def _segments_from_local_backend(audio_path: str, language: str) -> list[dict]:
282
- """Dispatch local whisper backend from startup device detection."""
283
- if LOCAL_WHISPER_BACKEND == "mlx-whisper":
 
 
284
  return _segments_from_mlx(audio_path, language)
285
 
286
- if LOCAL_WHISPER_BACKEND == "openai-whisper-cuda":
287
  print("[s2] Using openai-whisper backend (cuda)...")
288
  try:
289
  return _segments_from_openai_whisper(audio_path, language)
@@ -306,6 +344,8 @@ def transcribe(audio_path: str, language: str = "en") -> list[dict]:
306
  print(f"[s2] Transcribing {audio_path} (lang={language})...")
307
 
308
  segments = None
 
 
309
 
310
  # 1. Try Pollinations API first
311
  try:
@@ -317,21 +357,30 @@ def transcribe(audio_path: str, language: str = "en") -> list[dict]:
317
  segments = None
318
  except Exception as exc:
319
  print(f"[s2] Pollinations error ({exc}) — falling back to local backend.")
 
320
  segments = None
321
 
322
  # 2. Try Local Backend (GPU or CPU)
323
  if segments is None:
324
  try:
325
- print(f"[s2] Trying local backend ({LOCAL_WHISPER_BACKEND})...")
 
326
  segments = _segments_from_local_backend(audio_path, language)
327
  if segments:
328
  print(f"[s2] Local backend returned {len(segments)} segments ✓")
329
  except Exception as exc:
330
  print(f"[s2] Local backend error ({exc}).")
 
331
  segments = None
332
 
333
  if segments is None:
334
- raise RuntimeError("Transcription failed on all available backends.")
 
 
 
 
 
 
335
 
336
  before = len(segments)
337
  segments = _split_oversized_segments(segments)
 
24
  MLX_MODEL = os.getenv("MLX_WHISPER_MODEL", "mlx-community/whisper-large-mlx")
25
  FASTER_WHISPER_MODEL = os.getenv("FASTER_WHISPER_MODEL", "large-v3")
26
  OPENAI_WHISPER_MODEL = os.getenv("OPENAI_WHISPER_MODEL", "large-v3")
27
+ LOCAL_WHISPER_BACKEND_ENV = "VIDEOVOICE_WHISPER_BACKEND"
28
+ _VALID_LOCAL_BACKENDS = {
29
+ "mlx-whisper",
30
+ "openai-whisper-cuda",
31
+ "faster-whisper-cpu",
32
+ }
 
 
 
33
 
34
  _FASTER_WHISPER_MODELS = {}
35
  _OPENAI_WHISPER_MODEL = None
36
 
37
 
38
+ def _running_on_hf_space() -> bool:
39
+ return bool(
40
+ os.getenv("SPACE_ID")
41
+ or os.getenv("SPACE_HOST")
42
+ or os.getenv("HF_SPACE_ID")
43
+ )
44
+
45
+
46
+ def _get_local_whisper_backend() -> str:
47
+ """
48
+ Resolve the local transcription backend lazily.
49
+
50
+ On HF Spaces, default to CPU faster-whisper unless explicitly overridden.
51
+ ZeroGPU can report CUDA availability outside an active @spaces.GPU call,
52
+ which makes import-time backend selection unreliable.
53
+ """
54
+ override = os.getenv(LOCAL_WHISPER_BACKEND_ENV, "").strip().lower()
55
+ if override:
56
+ if override not in _VALID_LOCAL_BACKENDS:
57
+ raise ValueError(
58
+ f"Invalid {LOCAL_WHISPER_BACKEND_ENV}={override!r}. "
59
+ f"Expected one of: {', '.join(sorted(_VALID_LOCAL_BACKENDS))}."
60
+ )
61
+ return override
62
+
63
+ if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
64
+ return "mlx-whisper"
65
+
66
+ if _running_on_hf_space():
67
+ return "faster-whisper-cpu"
68
+
69
+ if torch.cuda.is_available():
70
+ # PyTorch-based path so @spaces.GPU can intercept the CUDA allocation.
71
+ # faster-whisper uses CTranslate2 which bypasses PyTorch and breaks ZeroGPU.
72
+ return "openai-whisper-cuda"
73
+
74
+ return "faster-whisper-cpu"
75
+
76
+
77
  def _extract_words(raw_words: list[dict]) -> list[dict]:
78
  """Normalise word timestamps into {word, start, end}."""
79
  output = []
 
315
 
316
 
317
  def _segments_from_local_backend(audio_path: str, language: str) -> list[dict]:
318
+ """Dispatch local whisper backend from runtime device detection."""
319
+ backend = _get_local_whisper_backend()
320
+
321
+ if backend == "mlx-whisper":
322
  return _segments_from_mlx(audio_path, language)
323
 
324
+ if backend == "openai-whisper-cuda":
325
  print("[s2] Using openai-whisper backend (cuda)...")
326
  try:
327
  return _segments_from_openai_whisper(audio_path, language)
 
344
  print(f"[s2] Transcribing {audio_path} (lang={language})...")
345
 
346
  segments = None
347
+ pollinations_error = None
348
+ local_error = None
349
 
350
  # 1. Try Pollinations API first
351
  try:
 
357
  segments = None
358
  except Exception as exc:
359
  print(f"[s2] Pollinations error ({exc}) — falling back to local backend.")
360
+ pollinations_error = exc
361
  segments = None
362
 
363
  # 2. Try Local Backend (GPU or CPU)
364
  if segments is None:
365
  try:
366
+ backend = _get_local_whisper_backend()
367
+ print(f"[s2] Trying local backend ({backend})...")
368
  segments = _segments_from_local_backend(audio_path, language)
369
  if segments:
370
  print(f"[s2] Local backend returned {len(segments)} segments ✓")
371
  except Exception as exc:
372
  print(f"[s2] Local backend error ({exc}).")
373
+ local_error = exc
374
  segments = None
375
 
376
  if segments is None:
377
+ details = []
378
+ if pollinations_error is not None:
379
+ details.append(f"Pollinations: {pollinations_error}")
380
+ if local_error is not None:
381
+ details.append(f"Local backend: {local_error}")
382
+ suffix = f" Details: {' | '.join(details)}" if details else ""
383
+ raise RuntimeError(f"Transcription failed on all available backends.{suffix}")
384
 
385
  before = len(segments)
386
  segments = _split_oversized_segments(segments)