Spaces:
Sleeping
Sleeping
from __future__ import annotations

import asyncio
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from faster_whisper import WhisperModel
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Directory containing this module; static front-end assets live beside it.
ROOT = Path(__file__).resolve().parent
STATIC_DIR = ROOT.joinpath("static")

# Runtime configuration. Each value can be overridden via the environment
# variable named in the lookup; the second argument is the fallback.
DEFAULT_WHISPER_MODEL = os.environ.get("WHISPERMATH_WHISPER_MODEL", "small.en")
DEFAULT_WHISPER_DEVICE = os.environ.get("WHISPERMATH_WHISPER_DEVICE", "cpu")
DEFAULT_WHISPER_COMPUTE_TYPE = os.environ.get("WHISPERMATH_WHISPER_COMPUTE_TYPE", "int8")
DEFAULT_DECODER_MODEL = os.environ.get(
    "WHISPERMATH_DECODER_MODEL", "vibhuiitj/byt5-base-whispermath-a100-checkpoint-10724"
)
def select_decoder_device(device_name: str | None = None) -> torch.device:
    """Resolve the torch device the seq2seq decoder should run on.

    Any non-empty name other than ``"auto"`` is handed straight to
    ``torch.device``. For ``None``, ``""`` or ``"auto"``, the first available
    backend wins, probed in the order CUDA, then MPS, then CPU as the fallback.
    """
    if not device_name or device_name == "auto":
        # Probes are lazy so later backends are only queried when earlier ones
        # are unavailable.
        probes = (
            ("cuda", lambda: torch.cuda.is_available()),
            ("mps", lambda: torch.backends.mps.is_available()),
        )
        for backend, is_available in probes:
            if is_available():
                return torch.device(backend)
        return torch.device("cpu")
    return torch.device(device_name)
@dataclass
class DemoModels:
    """Bundle of everything ``load_models`` produces for the request handlers.

    ``@dataclass`` generates the keyword ``__init__`` that ``load_models``
    relies on (``DemoModels(whisper=..., tokenizer=..., ...)``). Without it the
    class only has annotations, and that call raises ``TypeError``.
    """

    whisper: WhisperModel  # faster-whisper speech-to-text model
    tokenizer: Any  # tokenizer from AutoTokenizer.from_pretrained
    decoder: Any  # seq2seq model from AutoModelForSeq2SeqLM.from_pretrained
    decoder_device: torch.device  # device the decoder and its encoded inputs are moved to
class DecodeRequest(BaseModel):
    """Request body for the ``decode`` handler (text in, math text out).

    ``num_beams`` and ``max_new_tokens`` are accepted as sent; the handler
    clamps them with ``clamp_generation_args`` before generation.
    """

    text: str  # text to decode; the handler strips it and rejects empty input
    num_beams: int = 4  # requested beam width (clamped by the handler)
    max_new_tokens: int = 256  # requested generation length cap (clamped by the handler)
def load_models() -> DemoModels:
    """Load the Whisper transcriber and the seq2seq math decoder.

    The decoder device is resolved first, from ``WHISPERMATH_DECODER_DEVICE``
    (default ``"auto"``). Whisper is loaded next, then the decoder. Progress
    lines are printed with ``flush=True`` so they appear promptly in logs.

    Returns:
        A ``DemoModels`` bundle with both models, the tokenizer and the
        decoder device.
    """
    device = select_decoder_device(os.getenv("WHISPERMATH_DECODER_DEVICE", "auto"))
    whisper = _load_whisper()
    tokenizer, decoder = _load_decoder(device)
    print("WhisperMath web demo is ready.", flush=True)
    return DemoModels(
        whisper=whisper,
        tokenizer=tokenizer,
        decoder=decoder,
        decoder_device=device,
    )


def _load_whisper() -> WhisperModel:
    """Instantiate faster-whisper using the module-level model/device/compute settings."""
    print(
        f"Loading Whisper {DEFAULT_WHISPER_MODEL} "
        f"({DEFAULT_WHISPER_DEVICE}, {DEFAULT_WHISPER_COMPUTE_TYPE})...",
        flush=True,
    )
    return WhisperModel(
        DEFAULT_WHISPER_MODEL,
        device=DEFAULT_WHISPER_DEVICE,
        compute_type=DEFAULT_WHISPER_COMPUTE_TYPE,
    )


def _load_decoder(device: torch.device) -> tuple[Any, Any]:
    """Load the decoder tokenizer and model, move the model to *device*, set eval mode."""
    print(f"Loading decoder {DEFAULT_DECODER_MODEL} on {device}...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(DEFAULT_DECODER_MODEL)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        DEFAULT_DECODER_MODEL,
        low_cpu_mem_usage=True,
    )
    model = model.to(device)
    model.eval()
    return tokenizer, model
# Loaded model bundle shared by the handlers. It stays None until ``startup``
# assigns it; until then health reports "loading" and the model-backed
# handlers answer HTTP 503.
models: DemoModels | None = None

app = FastAPI(title="WhisperMath Web Demo")
# Permissive CORS for the demo: any origin, method and header are allowed,
# with credentials disabled (allow_credentials=False).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve the front-end files in STATIC_DIR under the "/static" URL prefix.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
def startup() -> None:
    """Load all models once and publish them via the module-level ``models``.

    Until this has run, ``models`` is ``None`` and the model-backed handlers
    answer HTTP 503.

    NOTE(review): no ``@app.on_event("startup")`` / lifespan registration is
    visible for this function; confirm it is wired up, otherwise the models
    are never loaded.
    """
    global models
    loaded = load_models()
    models = loaded
def index() -> FileResponse:
    """Serve the demo's single-page UI from ``STATIC_DIR/index.html``.

    NOTE(review): no route decorator is visible for this handler; confirm it
    is registered (presumably on ``/``) elsewhere.
    """
    page = STATIC_DIR.joinpath("index.html")
    return FileResponse(page)
def health() -> dict[str, str]:
    """Report readiness plus the configured model names and decoder device.

    ``status`` is ``"ok"`` once models are loaded and ``"loading"`` before
    that; ``decoder_device`` is ``"not_loaded"`` until the models exist.

    NOTE(review): no route decorator is visible for this handler; confirm it
    is registered elsewhere.
    """
    loaded = models
    ready = bool(loaded)
    return {
        "status": "ok" if ready else "loading",
        "whisper_model": DEFAULT_WHISPER_MODEL,
        "decoder_model": DEFAULT_DECODER_MODEL,
        "decoder_device": str(loaded.decoder_device) if ready else "not_loaded",
    }
def transcribe_audio(audio_path: Path) -> tuple[str, list[dict[str, float | str]]]:
    """Transcribe a recording with the loaded Whisper model.

    Whisper runs with ``language="en"``, beam size 5 and its VAD filter on.

    Args:
        audio_path: Path of the audio file on disk.

    Returns:
        ``(transcript, rows)``. ``rows`` holds one dict per segment, with
        ``start``/``end`` rounded to 3 decimals and whitespace-stripped
        ``text``. ``transcript`` is those texts joined with single spaces and
        stripped.

    Raises:
        RuntimeError: If ``models`` has not been loaded yet.
    """
    if models is None:
        raise RuntimeError("Models are still loading.")

    whisper_segments, _info = models.whisper.transcribe(
        str(audio_path),
        language="en",
        beam_size=5,
        vad_filter=True,
    )
    rows: list[dict[str, float | str]] = []
    for seg in whisper_segments:
        rows.append(
            {
                "start": round(seg.start, 3),
                "end": round(seg.end, 3),
                "text": seg.text.strip(),
            }
        )
    joined = " ".join(row["text"] for row in rows)
    return joined.strip(), rows
def decode_math_text(
    transcript: str,
    max_source_length: int = 512,
    max_new_tokens: int = 256,
    num_beams: int = 4,
    repetition_penalty: float = 1.25,
    no_repeat_ngram_size: int = 4,
) -> str:
    """Convert a transcript into math text with the seq2seq decoder.

    Args:
        transcript: Input text; an empty string short-circuits to ``""``.
        max_source_length: Tokenizer truncation limit for the input.
        max_new_tokens: Maximum number of generated tokens.
        num_beams: Beam width; ``early_stopping`` is enabled only when > 1.
        repetition_penalty: Passed through to ``generate``.
        no_repeat_ngram_size: Passed through to ``generate``.

    Returns:
        The first generated sequence, decoded without special tokens and
        stripped.

    Raises:
        RuntimeError: If ``models`` has not been loaded yet.
    """
    if models is None:
        raise RuntimeError("Models are still loading.")
    if not transcript:
        return ""

    tokenizer = models.tokenizer
    inputs = tokenizer(
        transcript,
        return_tensors="pt",
        max_length=max_source_length,
        truncation=True,
    )
    inputs = inputs.to(models.decoder_device)
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "early_stopping": num_beams > 1,
    }
    with torch.no_grad():  # inference only; no autograd bookkeeping
        output_ids = models.decoder.generate(**inputs, **generation_kwargs)
    best = output_ids[0]
    return tokenizer.decode(best, skip_special_tokens=True).strip()
def clamp_generation_args(num_beams: int, max_new_tokens: int) -> tuple[int, int]:
    """Clamp client-supplied generation settings into safe ranges.

    ``num_beams`` is limited to ``[1, 8]`` and ``max_new_tokens`` to
    ``[32, 1024]``.
    """
    beams = min(max(num_beams, 1), 8)
    new_tokens = min(max(max_new_tokens, 32), 1024)
    return beams, new_tokens
async def transcribe(
    audio: UploadFile = File(...),
    num_beams: int = Form(4),
    max_new_tokens: int = Form(256),
) -> dict[str, Any]:
    """Transcribe an uploaded recording and decode the transcript into math text.

    The upload is written to a temporary file because ``transcribe_audio``
    takes a path. The file keeps the client's extension, or ``.webm`` if none.
    It is always deleted afterwards, including when writing it fails.
    Generation arguments are clamped with ``clamp_generation_args``.

    NOTE(review): no route decorator is visible for this handler; confirm it
    is registered on the path the static UI posts to.

    Args:
        audio: Uploaded audio file.
        num_beams: Requested beam width (form field).
        max_new_tokens: Requested generation length cap (form field).

    Returns:
        The transcript, decoded math text, per-segment rows and model names.

    Raises:
        HTTPException: 503 while models are loading, 400 for an empty upload,
            500 (with the error message as detail) if transcription or
            decoding fails.
    """
    if models is None:
        raise HTTPException(status_code=503, detail="Models are still loading.")

    payload = await audio.read()
    if not payload:
        # An empty file can never be transcribed; report it as a client error
        # instead of letting the audio decoder fail with an opaque 500.
        raise HTTPException(status_code=400, detail="Audio file is empty.")

    suffix = Path(audio.filename or "recording.webm").suffix or ".webm"
    # Only reserve a unique path here. The bytes are written inside the try so
    # that a failed write still reaches the cleanup in ``finally``.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle:
        temp_path = Path(handle.name)
    try:
        temp_path.write_bytes(payload)
        # Whisper and the decoder are synchronous and heavy; running them in a
        # worker thread keeps the event loop free for other requests.
        transcript, segments = await asyncio.to_thread(transcribe_audio, temp_path)
        safe_num_beams, safe_max_new_tokens = clamp_generation_args(num_beams, max_new_tokens)
        math_text = await asyncio.to_thread(
            decode_math_text,
            transcript,
            num_beams=safe_num_beams,
            max_new_tokens=safe_max_new_tokens,
        )
    except Exception as exc:  # pragma: no cover - returned to the demo UI.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        temp_path.unlink(missing_ok=True)

    return {
        "transcript": transcript,
        "math_text": math_text,
        "segments": segments,
        "whisper_model": DEFAULT_WHISPER_MODEL,
        "decoder_model": DEFAULT_DECODER_MODEL,
    }
def decode(request: DecodeRequest) -> dict[str, Any]:
    """Decode already-transcribed text into math text.

    NOTE(review): no route decorator is visible for this handler; confirm it
    is registered elsewhere.

    Args:
        request: Body with the text and requested generation settings; the
            settings are clamped with ``clamp_generation_args``.

    Returns:
        The stripped input text, the decoded math text and the decoder name.

    Raises:
        HTTPException: 503 while models are loading, 400 if the text is empty
            after stripping.
    """
    if models is None:
        raise HTTPException(status_code=503, detail="Models are still loading.")

    text = request.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Text is required.")

    beams, new_tokens = clamp_generation_args(request.num_beams, request.max_new_tokens)
    return {
        "transcript": text,
        "math_text": decode_math_text(text, num_beams=beams, max_new_tokens=new_tokens),
        "decoder_model": DEFAULT_DECODER_MODEL,
    }