import os
import gradio as gr
import librosa
import spaces
import torch
from huggingface_hub import login
from transformers import (
WhisperFeatureExtractor,
WhisperForConditionalGeneration,
WhisperTokenizerFast,
)
def _patched_init(self, *args, **kwargs):
    """Wrapper around ``WhisperTokenizerFast.__init__`` that normalizes a
    list-valued ``extra_special_tokens`` kwarg into the dict form the
    tokenizer expects (each token mapped to itself)."""
    extra = kwargs.get("extra_special_tokens")
    if isinstance(extra, list):
        # Convert ["<a>", "<b>"] -> {"<a>": "<a>", "<b>": "<b>"}.
        kwargs["extra_special_tokens"] = dict(zip(extra, extra))
    return _original_tokenizer_init(self, *args, **kwargs)

# Keep a handle on the untouched initializer, then install the shim.
_original_tokenizer_init = WhisperTokenizerFast.__init__
WhisperTokenizerFast.__init__ = _patched_init
# Optional Hugging Face Hub auth — token is supplied via the HF_TOKEN env var
# (a Space secret); skip login entirely when it is absent.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Model repo is overridable via env; defaults to the Mongolian ASR checkpoint.
MODEL_REPO = os.getenv("MODEL_REPO_ID", "Batuka0901/MN_ASR")
# Audio is resampled to this rate before feature extraction (Whisper's 16 kHz).
SAMPLING_RATE = 16000

# Load everything on CPU at import time; the model is moved to CUDA lazily
# inside the @spaces.GPU-decorated handler (ZeroGPU has no GPU at startup).
print(f"Loading {MODEL_REPO} (CPU at startup) ...")
model = WhisperForConditionalGeneration.from_pretrained(MODEL_REPO, token=HF_TOKEN)
model.eval()  # inference only
tokenizer = WhisperTokenizerFast.from_pretrained(MODEL_REPO, token=HF_TOKEN)
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_REPO, token=HF_TOKEN)
print("Model loaded.")

# One-shot flag: set True after the model has been moved to CUDA.
_moved_to_cuda = False

# UI status strings (Mongolian): waiting for audio / ready / working / done.
WAITING = "Төлөв: **Аудио хүлээж байна...**"
READY = "Төлөв: **Илгээхэд бэлэн.**"
WORKING = "Төлөв: **Танилт хийж байна...**"
DONE = "Төлөв: **Дууссан.**"
@spaces.GPU(duration=60)
def transcribe(audio_path):
    """Transcribe the given audio file with the Whisper model.

    Returns a ``(text, status_markdown)`` pair for the Gradio outputs.
    An empty/missing path yields no text plus the WAITING status; any
    exception is reported via the status string instead of propagating.
    """
    global _moved_to_cuda
    if not audio_path:
        return "", WAITING
    try:
        # First GPU-backed call: migrate the model from CPU to CUDA once.
        if torch.cuda.is_available() and not _moved_to_cuda:
            model.to("cuda")
            _moved_to_cuda = True
        target = "cpu"
        if _moved_to_cuda and torch.cuda.is_available():
            target = "cuda"
        # librosa resamples to the model's expected 16 kHz on load.
        waveform, _ = librosa.load(audio_path, sr=SAMPLING_RATE)
        features = feature_extractor(
            waveform, sampling_rate=SAMPLING_RATE, return_tensors="pt"
        ).input_features.to(target)
        with torch.no_grad():
            token_ids = model.generate(features, language="mn", task="transcribe")
        transcript = tokenizer.batch_decode(token_ids, skip_special_tokens=True)[0]
        return transcript.strip(), DONE
    except Exception as e:
        # Surface failures in the status line rather than crashing the UI.
        return "", f"Төлөв: **Алдаа** — {type(e).__name__}: {e}"
def on_audio_change(audio_path):
    """Toggle the Generate button and status based on whether audio is set."""
    has_audio = bool(audio_path)
    return gr.update(interactive=has_audio), (READY if has_audio else WAITING)
def on_clear():
    """Disable the Generate button and reset the status to WAITING."""
    return (gr.update(interactive=False), WAITING)
# User-facing instructions (Mongolian), rendered as Markdown in the first tab.
INSTRUCTIONS = """
### Заавар
1. **Audio оруулна уу** — файл upload хийх эсвэл микрофоноор шууд бичлэг хийнэ
2. **Generate** товчийг дарна — таны хэлсэн үгийг загвар таньж текст болгоно
"""

# Custom CSS: hides the Gradio footer and the "Use via API" links.
CSS = """
footer { display: none !important; visibility: hidden !important; }
.gradio-container > .footer { display: none !important; }
button.api-link, .api-docs, a[href*="/api/"] { display: none !important; }
"""
# UI layout: two columns — audio input + buttons on the left, transcript and
# status on the right.
with gr.Blocks(title="Speech to Text", css=CSS) as demo:
    with gr.Tab("Speech to Text"):
        gr.Markdown(INSTRUCTIONS)
        with gr.Row():
            with gr.Column(scale=1):
                audio_in = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",  # transcribe() receives a file path
                    label="Audio",
                )
                with gr.Row():
                    clear_btn = gr.Button("Clear", variant="secondary", size="sm")
                    # Disabled until audio is provided (see on_audio_change).
                    submit_btn = gr.Button(
                        "Generate", variant="primary", size="sm", interactive=False
                    )
            with gr.Column(scale=1):
                text_out = gr.Textbox(label="Гаралт", lines=10)
                status = gr.Markdown(WAITING)
    # Enable/disable Generate as audio is added or removed.
    audio_in.change(on_audio_change, inputs=audio_in, outputs=[submit_btn, status])
    # Show the WORKING status immediately, then run the transcription.
    submit_btn.click(lambda: WORKING, outputs=status).then(
        transcribe, inputs=audio_in, outputs=[text_out, status]
    )
    # Clear wipes audio + transcript, then resets the button/status.
    clear_btn.click(
        lambda: (None, "", WAITING),
        outputs=[audio_in, text_out, status],
    ).then(on_clear, outputs=[submit_btn, status])
if __name__ == "__main__":
    # Enable request queueing and launch; show_api=False hides the API docs page.
    demo.queue().launch(show_api=False)