# Author: vibhuiitj
# Deploy WhisperMath web demo
# Commit 95c3887 (verified)
from __future__ import annotations

import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from faster_whisper import WhisperModel
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Directory containing this module; the front-end assets live beside it.
ROOT = Path(__file__).resolve().parent
STATIC_DIR = ROOT.joinpath("static")

# Model configuration. Each value can be overridden through an environment
# variable; the second argument to `get` is the fallback used otherwise.
# faster-whisper model name, device, and compute type used by load_models().
DEFAULT_WHISPER_MODEL = os.environ.get("WHISPERMATH_WHISPER_MODEL", "small.en")
DEFAULT_WHISPER_DEVICE = os.environ.get("WHISPERMATH_WHISPER_DEVICE", "cpu")
DEFAULT_WHISPER_COMPUTE_TYPE = os.environ.get("WHISPERMATH_WHISPER_COMPUTE_TYPE", "int8")
# Identifier passed to AutoTokenizer / AutoModelForSeq2SeqLM `from_pretrained`
# (presumably a Hugging Face Hub repo id, or a local path).
DEFAULT_DECODER_MODEL = os.environ.get(
    "WHISPERMATH_DECODER_MODEL",
    "vibhuiitj/byt5-base-whispermath-a100-checkpoint-10724",
)
def select_decoder_device(device_name: str | None = None) -> torch.device:
    """Choose the torch device that runs the seq2seq decoder.

    Args:
        device_name: Explicit device string such as ``"cpu"``, ``"cuda"`` or
            ``"cuda:1"``. ``None``, an empty or blank string, or ``"auto"``
            (case-insensitive, surrounding whitespace ignored) selects a
            device automatically.

    Returns:
        The requested device. In auto mode, the first available of CUDA,
        Apple MPS, and CPU.
    """
    # The value usually comes from an environment variable, so tolerate stray
    # whitespace and casing such as "AUTO" instead of passing it to torch.device.
    requested = (device_name or "").strip()
    if requested and requested.lower() != "auto":
        return torch.device(requested)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps is absent on older torch builds. Guard the lookup so
    # those installs fall back to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")
@dataclass
class DemoModels:
    """Container for everything load_models() loads, shared via the `models` global."""

    whisper: WhisperModel  # faster-whisper speech-to-text model
    tokenizer: Any  # tokenizer from AutoTokenizer.from_pretrained for the decoder
    decoder: Any  # seq2seq model from AutoModelForSeq2SeqLM.from_pretrained
    decoder_device: torch.device  # device the decoder lives on; encoded inputs are moved here
class DecodeRequest(BaseModel):
    """JSON body accepted by ``POST /api/decode``."""

    text: str  # transcript to decode; stripped by the endpoint and must be non-empty
    num_beams: int = 4  # beam width; the endpoint clamps it to [1, 8]
    max_new_tokens: int = 256  # generation length cap; the endpoint clamps it to [32, 1024]
def load_models() -> DemoModels:
    """Load the Whisper transcriber and the seq2seq decoder.

    Configuration comes from the module-level DEFAULT_* settings. The decoder
    device comes from the WHISPERMATH_DECODER_DEVICE environment variable
    (default "auto"), resolved by select_decoder_device(). Progress messages
    are printed to stdout.

    Returns:
        A DemoModels whose decoder has been moved to its device and switched
        to eval mode.
    """
    target_device = select_decoder_device(
        os.getenv("WHISPERMATH_DECODER_DEVICE", "auto")
    )

    # Speech-to-text model.
    print(
        f"Loading Whisper {DEFAULT_WHISPER_MODEL} "
        f"({DEFAULT_WHISPER_DEVICE}, {DEFAULT_WHISPER_COMPUTE_TYPE})...",
        flush=True,
    )
    speech_model = WhisperModel(
        DEFAULT_WHISPER_MODEL,
        device=DEFAULT_WHISPER_DEVICE,
        compute_type=DEFAULT_WHISPER_COMPUTE_TYPE,
    )

    # Text-to-math decoder and its tokenizer.
    print(f"Loading decoder {DEFAULT_DECODER_MODEL} on {target_device}...", flush=True)
    text_tokenizer = AutoTokenizer.from_pretrained(DEFAULT_DECODER_MODEL)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(
        DEFAULT_DECODER_MODEL,
        low_cpu_mem_usage=True,
    )
    seq2seq = seq2seq.to(target_device)
    # Inference only: eval() switches off training-mode layers such as dropout.
    seq2seq.eval()

    print("WhisperMath web demo is ready.", flush=True)
    return DemoModels(
        whisper=speech_model,
        tokenizer=text_tokenizer,
        decoder=seq2seq,
        decoder_device=target_device,
    )
# Populated by the startup hook with the result of load_models(); stays None
# until loading succeeds. Endpoints check it to report "loading" / HTTP 503.
models: DemoModels | None = None

app = FastAPI(title="WhisperMath Web Demo")
# Permissive CORS: any origin, method, and header may call the API;
# credentials (cookies/auth headers) are not allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve files from STATIC_DIR under the /static URL prefix.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
@app.on_event("startup")
def startup() -> None:
    """Load all models once, when the application starts.

    The result is published through the module-level ``models`` global that
    the endpoints read. If loading raises, ``models`` is left as ``None``.
    NOTE(review): newer FastAPI versions deprecate ``on_event`` in favour of
    lifespan handlers; consider migrating.
    """
    global models
    loaded = load_models()
    models = loaded
@app.get("/")
def index() -> FileResponse:
    """Serve the demo page, ``index.html`` from STATIC_DIR."""
    page = STATIC_DIR.joinpath("index.html")
    return FileResponse(page)
@app.get("/api/health")
def health() -> dict[str, str]:
    """Report readiness together with the configured model names.

    ``status`` is ``"ok"`` once models are loaded and ``"loading"`` before
    that. ``decoder_device`` is ``"not_loaded"`` until the decoder exists.
    """
    if models:
        status, device_label = "ok", str(models.decoder_device)
    else:
        status, device_label = "loading", "not_loaded"
    return {
        "status": status,
        "whisper_model": DEFAULT_WHISPER_MODEL,
        "decoder_model": DEFAULT_DECODER_MODEL,
        "decoder_device": device_label,
    }
def transcribe_audio(audio_path: Path) -> tuple[str, list[dict[str, float | str]]]:
    """Run Whisper speech-to-text on an audio file.

    Decoding is pinned to English, with beam size 5 and faster-whisper's VAD
    filter enabled.

    Args:
        audio_path: Path to an audio file that faster-whisper can read.

    Returns:
        ``(transcript, segments)``. ``segments`` has one dict per Whisper
        segment, with ``"start"`` / ``"end"`` timestamps rounded to 3 decimals
        (presumably seconds) and the stripped ``"text"``. ``transcript`` is
        those texts joined by single spaces, then stripped.

    Raises:
        RuntimeError: If the models have not been loaded yet.
    """
    if models is None:
        raise RuntimeError("Models are still loading.")

    segment_iter, _ = models.whisper.transcribe(
        str(audio_path),
        language="en",
        beam_size=5,
        vad_filter=True,
    )

    rows: list[dict[str, float | str]] = []
    # Consume the segments here, while the caller's audio file still exists
    # (presumably faster-whisper yields them lazily - confirm).
    for seg in segment_iter:
        rows.append(
            {
                "start": round(seg.start, 3),
                "end": round(seg.end, 3),
                "text": seg.text.strip(),
            }
        )

    texts = [row["text"] for row in rows]
    return " ".join(texts).strip(), rows
def decode_math_text(
    transcript: str,
    max_source_length: int = 512,
    max_new_tokens: int = 256,
    num_beams: int = 4,
    repetition_penalty: float = 1.25,
    no_repeat_ngram_size: int = 4,
) -> str:
    """Run the seq2seq decoder over a transcript and return its decoded text.

    Args:
        transcript: Input text. An empty string short-circuits to ``""``.
        max_source_length: Inputs longer than this many tokens are truncated.
        max_new_tokens: Upper bound on generated tokens.
        num_beams: Beam width; early stopping is enabled only when it is > 1.
        repetition_penalty: Forwarded to ``generate``.
        no_repeat_ngram_size: Forwarded to ``generate``.

    Returns:
        The best (first) generated sequence, decoded without special tokens
        and stripped.

    Raises:
        RuntimeError: If the models have not been loaded yet.
    """
    if models is None:
        raise RuntimeError("Models are still loading.")
    if not transcript:
        return ""

    tok = models.tokenizer
    inputs = tok(
        transcript,
        return_tensors="pt",
        max_length=max_source_length,
        truncation=True,
    )
    inputs = inputs.to(models.decoder_device)

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        # Only meaningful for beam search, so tie it to num_beams > 1.
        early_stopping=num_beams > 1,
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        generated = models.decoder.generate(**inputs, **generation_kwargs)

    best = generated[0]
    return tok.decode(best, skip_special_tokens=True).strip()
def clamp_generation_args(num_beams: int, max_new_tokens: int) -> tuple[int, int]:
    """Bound user-supplied generation settings to safe ranges.

    Returns:
        ``(num_beams, max_new_tokens)`` with ``num_beams`` clamped to
        ``[1, 8]`` and ``max_new_tokens`` clamped to ``[32, 1024]``.
    """

    def _clamp(value: int, lower: int, upper: int) -> int:
        # Cap at `upper` first, then raise to at least `lower`.
        return max(lower, min(value, upper))

    return _clamp(num_beams, 1, 8), _clamp(max_new_tokens, 32, 1024)
def _transcribe_and_decode(
    audio_path: Path,
    num_beams: int,
    max_new_tokens: int,
) -> tuple[str, list[dict[str, float | str]], str]:
    """Run the blocking Whisper and decoder pipeline on a single audio file.

    Kept synchronous so the async endpoint can hand it to a worker thread.
    Returns ``(transcript, segments, math_text)``.
    """
    transcript, segments = transcribe_audio(audio_path)
    safe_num_beams, safe_max_new_tokens = clamp_generation_args(num_beams, max_new_tokens)
    math_text = decode_math_text(
        transcript,
        num_beams=safe_num_beams,
        max_new_tokens=safe_max_new_tokens,
    )
    return transcript, segments, math_text


@app.post("/api/transcribe")
async def transcribe(
    audio: UploadFile = File(...),
    num_beams: int = Form(4),
    max_new_tokens: int = Form(256),
) -> dict[str, Any]:
    """Transcribe an uploaded recording and decode it into math text.

    Args:
        audio: Uploaded audio file. It is written to a temporary file whose
            suffix is taken from the upload's filename (default ``.webm``).
        num_beams: Beam width for the decoder; clamped to ``[1, 8]``.
        max_new_tokens: Decoder length cap; clamped to ``[32, 1024]``.

    Returns:
        JSON with ``transcript``, ``math_text``, ``segments`` and the names of
        both models.

    Raises:
        HTTPException: 503 while models are loading, 400 for an empty upload,
            and 500 (with the error message) if transcription or decoding
            fails.
    """
    if models is None:
        raise HTTPException(status_code=503, detail="Models are still loading.")

    # Read the upload before creating the temp file, so a failed read
    # cannot leave an orphaned file behind (it is created with delete=False).
    audio_bytes = await audio.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded audio is empty.")

    suffix = Path(audio.filename or "recording.webm").suffix or ".webm"
    handle = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    temp_path = Path(handle.name)
    try:
        # Close (and flush) the temp file before Whisper reopens it by path.
        with handle:
            handle.write(audio_bytes)
        # Transcription and beam search are heavy, blocking calls. Run them in
        # FastAPI's threadpool so this async endpoint does not stall the event
        # loop (and with it /api/health and every other request).
        transcript, segments, math_text = await run_in_threadpool(
            _transcribe_and_decode,
            temp_path,
            num_beams,
            max_new_tokens,
        )
    except Exception as exc:  # pragma: no cover - returned to the demo UI.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        temp_path.unlink(missing_ok=True)

    return {
        "transcript": transcript,
        "math_text": math_text,
        "segments": segments,
        "whisper_model": DEFAULT_WHISPER_MODEL,
        "decoder_model": DEFAULT_DECODER_MODEL,
    }
@app.post("/api/decode")
def decode(request: DecodeRequest) -> dict[str, Any]:
    """Decode a transcript supplied directly in the JSON body (no audio).

    Raises:
        HTTPException: 503 while models are loading; 400 when ``text`` is
            blank after stripping.
    """
    if models is None:
        raise HTTPException(status_code=503, detail="Models are still loading.")

    text = request.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Text is required.")

    beams, token_budget = clamp_generation_args(
        request.num_beams,
        request.max_new_tokens,
    )
    return {
        "transcript": text,
        "math_text": decode_math_text(
            text,
            num_beams=beams,
            max_new_tokens=token_budget,
        ),
        "decoder_model": DEFAULT_DECODER_MODEL,
    }