# mn_stt / app.py — Mongolian speech-to-text (Whisper) Gradio Space
# Provenance (from the scraped page header): uploaded by Batuka0901,
# commit c43c11e (verified).
import os
import gradio as gr
import librosa
import spaces
import torch
from huggingface_hub import login
from transformers import (
WhisperFeatureExtractor,
WhisperForConditionalGeneration,
WhisperTokenizerFast,
)
# Compatibility shim: some checkpoints ship `extra_special_tokens` as a
# list, while the tokenizer constructor appears to expect a mapping
# (NOTE(review): presumably a transformers version mismatch — confirm).
# Coerce list -> {token: token} before delegating to the real __init__.
_orig_init = WhisperTokenizerFast.__init__


def _patched_init(self, *args, **kwargs):
    extra = kwargs.get("extra_special_tokens")
    if isinstance(extra, list):
        kwargs["extra_special_tokens"] = dict(zip(extra, extra))
    return _orig_init(self, *args, **kwargs)


WhisperTokenizerFast.__init__ = _patched_init
# Hub authentication: log in only when a token is supplied via the
# HF_TOKEN environment variable (needed if the model repo is private).
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Model repo is overridable through the environment; defaults to the
# fine-tuned Mongolian ASR checkpoint.
MODEL_REPO = os.getenv("MODEL_REPO_ID", "Batuka0901/MN_ASR")
# Rate (Hz) audio is resampled to before feature extraction.
SAMPLING_RATE = 16000

print(f"Loading {MODEL_REPO} (CPU at startup) ...")
# Load on CPU at import time; the model is moved to CUDA lazily inside
# the @spaces.GPU-decorated transcribe() handler.
model = WhisperForConditionalGeneration.from_pretrained(MODEL_REPO, token=HF_TOKEN)
model.eval()
tokenizer = WhisperTokenizerFast.from_pretrained(MODEL_REPO, token=HF_TOKEN)
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_REPO, token=HF_TOKEN)
print("Model loaded.")

# Set to True by transcribe() once the model has been moved to CUDA,
# so the move happens at most once per process.
_moved_to_cuda = False

# Status strings rendered as Markdown in the UI (Mongolian).
WAITING = "Төлөв: **Аудио хүлээж байна...**"
READY = "Төлөв: **Илгээхэд бэлэн.**"
WORKING = "Төлөв: **Танилт хийж байна...**"
DONE = "Төлөв: **Дууссан.**"
@spaces.GPU(duration=60)
def transcribe(audio_path):
    """Transcribe the audio file at *audio_path* with the Whisper model.

    Returns a ``(text, status_markdown)`` pair for the two output
    components. An empty/None path yields ``("", WAITING)``; any
    runtime failure is reported through the status string instead of
    raising, so the UI never crashes.
    """
    global _moved_to_cuda
    if not audio_path:
        return "", WAITING
    try:
        # One-time lazy move to the GPU (available inside @spaces.GPU).
        if torch.cuda.is_available() and not _moved_to_cuda:
            model.to("cuda")
            _moved_to_cuda = True
        device = "cpu"
        if _moved_to_cuda and torch.cuda.is_available():
            device = "cuda"

        # Decode + resample the file to the model's expected rate.
        waveform, _ = librosa.load(audio_path, sr=SAMPLING_RATE)
        batch = feature_extractor(
            waveform, sampling_rate=SAMPLING_RATE, return_tensors="pt"
        )
        features = batch.input_features.to(device)

        with torch.no_grad():
            generated_ids = model.generate(
                features, language="mn", task="transcribe"
            )
        transcript = tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        return transcript.strip(), DONE
    except Exception as exc:
        # Surface the failure in the status area rather than crashing the UI.
        return "", f"Төлөв: **Алдаа** — {type(exc).__name__}: {exc}"
def on_audio_change(audio_path):
    """Enable the Generate button (and update status) once audio exists."""
    has_audio = bool(audio_path)
    return gr.update(interactive=has_audio), (READY if has_audio else WAITING)
def on_clear():
    """Disable the Generate button and reset the status after a clear."""
    disabled_button = gr.update(interactive=False)
    return disabled_button, WAITING
# User-facing instructions (Mongolian), rendered as Markdown above the UI.
INSTRUCTIONS = """
### Заавар
1. **Audio оруулна уу** — файл upload хийх эсвэл микрофоноор шууд бичлэг хийнэ
2. **Generate** товчийг дарна — таны хэлсэн үгийг загвар таньж текст болгоно
"""

# Hide Gradio's footer and API-docs links from the hosted page.
CSS = """
footer { display: none !important; visibility: hidden !important; }
.gradio-container > .footer { display: none !important; }
button.api-link, .api-docs, a[href*="/api/"] { display: none !important; }
"""
# UI layout: audio input + buttons on the left, transcript + status on
# the right, wired so the Generate button is only active when audio exists.
with gr.Blocks(title="Speech to Text", css=CSS) as demo:
    with gr.Tab("Speech to Text"):
        gr.Markdown(INSTRUCTIONS)
        with gr.Row():
            with gr.Column(scale=1):
                audio_in = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Audio",
                )
                with gr.Row():
                    clear_btn = gr.Button("Clear", variant="secondary", size="sm")
                    # Starts disabled; enabled by on_audio_change once audio arrives.
                    submit_btn = gr.Button(
                        "Generate", variant="primary", size="sm", interactive=False
                    )
            with gr.Column(scale=1):
                text_out = gr.Textbox(label="Гаралт", lines=10)
                status = gr.Markdown(WAITING)

    # Toggle the Generate button whenever the audio value changes.
    audio_in.change(on_audio_change, inputs=audio_in, outputs=[submit_btn, status])
    # Show the "working" status immediately, then run the transcription.
    submit_btn.click(lambda: WORKING, outputs=status).then(
        transcribe, inputs=audio_in, outputs=[text_out, status]
    )
    # Clear wipes audio/text/status, then disables the Generate button.
    clear_btn.click(
        lambda: (None, "", WAITING),
        outputs=[audio_in, text_out, status],
    ).then(on_clear, outputs=[submit_btn, status])
# Entry point: enable request queuing and hide the auto-generated API page.
if __name__ == "__main__":
    demo.queue().launch(show_api=False)