# Copyright 2025 Xiaomi Corporation.
import os
import time
import traceback

import gradio as gr
import torch
from huggingface_hub import snapshot_download

from src.mimo_audio.mimo_audio import MimoAudio
# Hugging Face repositories for the ASR model and its audio tokenizer.
MODEL_REPO = "XiaomiMiMo/MiMo-V2.5-ASR"
TOKENIZER_REPO = "XiaomiMiMo/MiMo-Audio-Tokenizer"

# Local directory that downloaded snapshots are written into; override it
# with the MIMO_DOWNLOAD_ROOT environment variable.
DOWNLOAD_ROOT = os.getenv("MIMO_DOWNLOAD_ROOT", "assets/models")

# UI language choice -> tag forwarded to the model as ``audio_tag``.
# "Auto" maps to an empty tag. Insertion order is the order of the radio
# choices shown in the UI.
LANGUAGE_TAGS = dict(
    Auto="",
    Chinese="<chinese>",
    English="<english>",
)
def download_models():
    """Fetch the ASR model and tokenizer snapshots into ``DOWNLOAD_ROOT``.

    Each repository is stored under ``DOWNLOAD_ROOT`` in a directory named
    after its repo id with ``/`` replaced by ``_``. The optional ``HF_TOKEN``
    environment variable is forwarded to ``snapshot_download``.

    Returns:
        ``(model_path, tokenizer_path)``: the two local snapshot directories.
    """
    os.makedirs(DOWNLOAD_ROOT, exist_ok=True)
    token = os.getenv("HF_TOKEN")

    model_dir, tokenizer_dir = (
        os.path.join(DOWNLOAD_ROOT, repo.replace("/", "_"))
        for repo in (MODEL_REPO, TOKENIZER_REPO)
    )

    # Model first, then tokenizer, matching the order they are logged.
    for repo, target in ((MODEL_REPO, model_dir), (TOKENIZER_REPO, tokenizer_dir)):
        print(f"[download] {repo} -> {target}")
        snapshot_download(repo_id=repo, token=token, local_dir=target)

    return model_dir, tokenizer_dir
class ASRGenerator:
    """Thin adapter that exposes speech recognition as ``transcribe``."""

    def __init__(self, model):
        # Any object providing ``asr_sft(audio_path, audio_tag=...)``;
        # in this app it is the loaded MimoAudio instance.
        self.model = model

    def transcribe(self, audio_path, audio_tag=""):
        """Run ASR on ``audio_path`` and return whatever ``asr_sft`` returns.

        Args:
            audio_path: Path to the audio file to transcribe.
            audio_tag: Optional language tag; an empty string means no tag.
        """
        run_asr = self.model.asr_sft
        return run_asr(audio_path, audio_tag=audio_tag)
class MiMoV25ASRInterface:
    """Gradio front end for MiMo-V2.5-ASR speech recognition.

    The ``MimoAudio`` model is loaded once at construction time.
    :meth:`transcribe` is the click handler, and :meth:`create_interface`
    builds the ``gr.Blocks`` UI.
    """

    def __init__(self, model_path, tokenizer_path):
        """Load the model and tokenizer from local directories.

        Args:
            model_path: Local directory holding the ASR model snapshot.
            tokenizer_path: Local directory holding the audio tokenizer snapshot.
        """
        # The device is only logged here; it is not passed to MimoAudio.
        # NOTE(review): presumably MimoAudio picks its own device — confirm.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[init] device={device}")
        print(f"[init] model_path={model_path}")
        print(f"[init] tokenizer_path={tokenizer_path}")
        self.model = MimoAudio(model_path, tokenizer_path)
        self.asr_generator = ASRGenerator(self.model)
        print("[init] model ready")

    def transcribe(self, uploaded_audio, recorded_audio, language_choice):
        """Transcribe the uploaded file, falling back to the microphone recording.

        Args:
            uploaded_audio: Filepath from the upload widget, or None.
            recorded_audio: Filepath from the microphone widget, or None.
            language_choice: A key of ``LANGUAGE_TAGS``. Unknown keys fall
                back to an empty tag (automatic detection).

        Returns:
            A ``(transcript, status_message)`` tuple. On failure the
            transcript is ``""`` and the status explains the error. Model
            errors are reported in the status and never re-raised, so the UI
            handler does not crash.
        """
        # An upload takes precedence over a recording when both are present.
        # Treat an empty path like a missing one rather than handing "" to the model.
        audio_path = uploaded_audio or recorded_audio
        if not audio_path:
            return "", "❌ Error: Please upload an audio file or record from your microphone."

        audio_tag = LANGUAGE_TAGS.get(language_choice, "")
        print("Performing ASR task:")
        print(f" Audio: {audio_path}")
        print(f" Language: {language_choice} (tag='{audio_tag}')")

        # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
        start = time.perf_counter()
        try:
            transcript = self.asr_generator.transcribe(audio_path, audio_tag=audio_tag)
        except Exception as e:  # UI boundary: report the error instead of crashing the handler
            error_msg = f"❌ Error during transcription: {e}"
            print(error_msg)
            traceback.print_exc()  # keep the full stack in the server log
            return "", error_msg
        elapsed = time.perf_counter() - start

        status_msg = (
            f"✅ Transcription completed in {elapsed:.2f}s\n"
            f"🎵 Input audio: {os.path.basename(audio_path)}\n"
            f"🌐 Language tag: {language_choice}"
        )
        return transcript, status_msg

    def create_interface(self):
        """Build the Gradio Blocks UI and wire up its event handlers.

        Returns:
            The ``gr.Blocks`` app, not yet queued or launched.
        """
        with gr.Blocks(
            title="MiMo-V2.5-ASR Speech Recognition",
            theme=gr.themes.Soft(),
            fill_height=True,
            analytics_enabled=False,
        ) as iface:
            gr.Markdown("# 🎙️ MiMo-V2.5-ASR: Robust Speech Recognition")
            gr.Markdown(
                "Upload an audio file **or** record directly from your microphone. "
                "Supports Chinese, English, Chinese dialects, code-switch, singing, "
                "noisy environments, and multi-speaker scenarios."
            )
            with gr.Row():
                # Left column: audio inputs, language selection, and the run button.
                with gr.Column():
                    uploaded_audio = gr.Audio(
                        label="Upload Audio File",
                        type="filepath",
                        sources=["upload"],
                        interactive=True,
                    )
                    recorded_audio = gr.Audio(
                        label="Or Record from Microphone",
                        type="filepath",
                        sources=["microphone"],
                        interactive=True,
                    )
                    language_choice = gr.Radio(
                        label="Language Tag",
                        choices=list(LANGUAGE_TAGS.keys()),
                        value="Auto",
                        info=(
                            "Auto: automatic language detection (recommended for "
                            "code-switched speech). Select Chinese or English to "
                            "bias the model toward that language."
                        ),
                    )
                    transcribe_btn = gr.Button(
                        "🎧 Transcribe", variant="primary", size="lg"
                    )
                # Right column: read-only transcript and status output.
                with gr.Column():
                    output_text = gr.Textbox(
                        label="Transcription",
                        lines=10,
                        interactive=False,
                        placeholder="Transcription result will appear here...",
                        show_copy_button=True,
                    )
                    status = gr.Textbox(
                        label="Status",
                        lines=4,
                        interactive=False,
                        placeholder="Processing status will be shown here...",
                    )
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear", size="sm")

            transcribe_btn.click(
                fn=self.transcribe,
                inputs=[uploaded_audio, recorded_audio, language_choice],
                outputs=[output_text, status],
            )

            def clear_all():
                # Reset both audio inputs, restore the default language, and blank the outputs.
                return None, None, "Auto", "", ""

            clear_btn.click(
                fn=clear_all,
                outputs=[
                    uploaded_audio,
                    recorded_audio,
                    language_choice,
                    output_text,
                    status,
                ],
            )
        return iface
def main():
    """Download the weights, build the Gradio UI, and serve it.

    The bind address comes from ``GRADIO_SERVER_NAME`` (default ``0.0.0.0``).
    The port comes from ``GRADIO_SERVER_PORT`` (default ``7898``).

    Raises:
        ValueError: if ``GRADIO_SERVER_PORT`` is not an integer.
    """
    # Resolve the server settings first, so a bad port fails immediately
    # instead of after the model download and load.
    host = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    raw_port = os.environ.get("GRADIO_SERVER_PORT", "7898")
    try:
        port = int(raw_port)
    except ValueError as err:
        raise ValueError(
            f"GRADIO_SERVER_PORT must be an integer, got {raw_port!r}"
        ) from err

    print("🚀 Launch MiMo-V2.5-ASR demo...")
    model_path, tokenizer_path = download_models()
    interface = MiMoV25ASRInterface(model_path, tokenizer_path)
    iface = interface.create_interface()

    print(f"🌐 Launch service - {host}:{port}")
    # Event listeners default to 4 concurrent jobs each; the queue holds at
    # most 20 pending events.
    iface.queue(default_concurrency_limit=4, max_size=20).launch(
        server_name=host,
        server_port=port,
        show_api=False,
    )


if __name__ == "__main__":
    main()