File size: 4,238 Bytes
05ccc32
 
 
4306377
 
05ccc32
 
 
 
 
 
 
 
521b74d
05ccc32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4306377
05ccc32
4306377
05ccc32
4306377
05ccc32
4306377
05ccc32
4306377
05ccc32
 
 
 
 
 
 
 
4306377
05ccc32
4306377
05ccc32
 
 
4306377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05ccc32
 
 
 
 
 
 
 
 
 
 
 
 
 
4306377
 
 
 
c43c11e
4306377
 
 
05ccc32
 
 
 
 
 
 
 
4306377
05ccc32
1cd819a
 
 
 
 
 
 
 
 
c43c11e
1cd819a
 
 
05ccc32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os

import gradio as gr
import librosa
import spaces
import torch
from huggingface_hub import login
from transformers import (
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperTokenizerFast,
)


# Compatibility shim: some tokenizer configs on the Hub ship
# `extra_special_tokens` as a list, while the tokenizer's __init__ expects a
# mapping. Wrap WhisperTokenizerFast.__init__ to normalize the value before
# delegating to the original implementation.
_orig_init = WhisperTokenizerFast.__init__


def _patched_init(self, *args, **kwargs):
    """Normalize a list-valued `extra_special_tokens` kwarg to a dict."""
    tokens = kwargs.get("extra_special_tokens")
    if isinstance(tokens, list):
        # Each token maps to itself, matching the expected dict form.
        kwargs["extra_special_tokens"] = dict(zip(tokens, tokens))
    return _orig_init(self, *args, **kwargs)


WhisperTokenizerFast.__init__ = _patched_init


# Authenticate with the Hugging Face Hub when a token is provided
# (needed to download private model repositories).
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Model repo is overridable via the MODEL_REPO_ID env var.
MODEL_REPO = os.getenv("MODEL_REPO_ID", "Batuka0901/MN_ASR")
# Whisper feature extraction operates on 16 kHz audio.
SAMPLING_RATE = 16000

# Load model, tokenizer, and feature extractor on CPU at startup; the model
# is moved to the GPU lazily inside `transcribe` (see `_moved_to_cuda`).
print(f"Loading {MODEL_REPO} (CPU at startup) ...")
model = WhisperForConditionalGeneration.from_pretrained(MODEL_REPO, token=HF_TOKEN)
model.eval()
tokenizer = WhisperTokenizerFast.from_pretrained(MODEL_REPO, token=HF_TOKEN)
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_REPO, token=HF_TOKEN)
print("Model loaded.")
# Flag: has the model already been moved to CUDA? Set once by `transcribe`.
_moved_to_cuda = False


# Status messages rendered in the UI as Markdown (Mongolian):
# waiting for audio / ready to submit / transcribing / finished.
WAITING = "Төлөв: **Аудио хүлээж байна...**"
READY = "Төлөв: **Илгээхэд бэлэн.**"
WORKING = "Төлөв: **Танилт хийж байна...**"
DONE = "Төлөв: **Дууссан.**"


@spaces.GPU(duration=60)
def transcribe(audio_path):
    """Transcribe the audio file at *audio_path* into Mongolian text.

    Returns a ``(text, status_markdown)`` pair for the two Gradio outputs.
    On any failure the text is empty and the status carries the error.
    """
    global _moved_to_cuda
    if not audio_path:
        return "", WAITING
    try:
        # Move the model to the GPU the first time one is available;
        # afterwards just reuse it where it already lives.
        if not _moved_to_cuda and torch.cuda.is_available():
            model.to("cuda")
            _moved_to_cuda = True
        use_gpu = _moved_to_cuda and torch.cuda.is_available()
        device = "cuda" if use_gpu else "cpu"

        # Resample to the 16 kHz rate Whisper expects, then build the
        # log-mel input features on the model's device.
        waveform, _ = librosa.load(audio_path, sr=SAMPLING_RATE)
        features = feature_extractor(
            waveform, sampling_rate=SAMPLING_RATE, return_tensors="pt"
        ).input_features.to(device)

        with torch.no_grad():
            token_ids = model.generate(features, language="mn", task="transcribe")

        decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
        return decoded[0].strip(), DONE
    except Exception as e:
        # Surface the failure in the status line instead of crashing the UI.
        return "", f"Төлөв: **Алдаа** — {type(e).__name__}: {e}"


def on_audio_change(audio_path):
    """Enable the Generate button when audio is present, disable otherwise."""
    has_audio = bool(audio_path)
    status = READY if has_audio else WAITING
    return gr.update(interactive=has_audio), status


def on_clear():
    """Disable the Generate button and show the waiting status."""
    button_state = gr.update(interactive=False)
    return button_state, WAITING


# User-facing instructions (Mongolian), rendered as Markdown in the UI.
INSTRUCTIONS = """
### Заавар

1. **Audio оруулна уу** — файл upload хийх эсвэл микрофоноор шууд бичлэг хийнэ
2. **Generate** товчийг дарна — таны хэлсэн үгийг загвар таньж текст болгоно

"""

# Custom CSS: hide Gradio's footer and the "Use via API" links.
CSS = """
footer { display: none !important; visibility: hidden !important; }
.gradio-container > .footer { display: none !important; }
button.api-link, .api-docs, a[href*="/api/"] { display: none !important; }
"""

# UI layout: audio input + Clear/Generate buttons on the left, transcription
# output on the right, with a status line below. Event wiring keeps the
# Generate button disabled until audio is present.
with gr.Blocks(title="Speech to Text", css=CSS) as demo:
    with gr.Tab("Speech to Text"):
        gr.Markdown(INSTRUCTIONS)
        with gr.Row():
            with gr.Column(scale=1):
                audio_in = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Audio",
                )
                with gr.Row():
                    clear_btn = gr.Button("Clear", variant="secondary", size="sm")
                    # Starts disabled; enabled by on_audio_change once audio exists.
                    submit_btn = gr.Button(
                        "Generate", variant="primary", size="sm", interactive=False
                    )
            with gr.Column(scale=1):
                text_out = gr.Textbox(label="Гаралт", lines=10)

        status = gr.Markdown(WAITING)

        # Toggle the Generate button whenever the audio component changes.
        audio_in.change(on_audio_change, inputs=audio_in, outputs=[submit_btn, status])
        # Show the "working" status immediately, then run the transcription.
        submit_btn.click(lambda: WORKING, outputs=status).then(
            transcribe, inputs=audio_in, outputs=[text_out, status]
        )
        # Clear audio, output text, and status, then disable the button again.
        clear_btn.click(
            lambda: (None, "", WAITING),
            outputs=[audio_in, text_out, status],
        ).then(on_clear, outputs=[submit_btn, status])


if __name__ == "__main__":
    # Enable request queuing and launch without exposing the API docs page.
    demo.queue().launch(show_api=False)