# Copyright 2025 Xiaomi Corporation.
"""Gradio demo for MiMo-V2.5-ASR speech recognition.

Downloads the ASR model and the audio tokenizer from the Hugging Face Hub,
loads them through ``MimoAudio``, and serves a web UI where a user can upload
or record audio and receive a transcript.

Environment variables read by this script:
    MIMO_DOWNLOAD_ROOT: directory for downloaded checkpoints
        (default ``assets/models``).
    HF_TOKEN: optional Hugging Face access token.
    GRADIO_SERVER_NAME: host interface to bind (default ``0.0.0.0``).
    GRADIO_SERVER_PORT: port to bind (default ``7898``).
"""

import os
import time
import traceback

import gradio as gr
import torch
from huggingface_hub import snapshot_download

from src.mimo_audio.mimo_audio import MimoAudio

MODEL_REPO = "XiaomiMiMo/MiMo-V2.5-ASR"
TOKENIZER_REPO = "XiaomiMiMo/MiMo-Audio-Tokenizer"
DOWNLOAD_ROOT = os.environ.get("MIMO_DOWNLOAD_ROOT", "assets/models")

# Maps the UI language choice to the ``audio_tag`` forwarded to
# ``MimoAudio.asr_sft``.
# NOTE(review): every value here is an empty string, so choosing "Chinese" or
# "English" sends exactly the same tag as "Auto", even though the UI text says
# those options bias the model toward that language — confirm the intended
# tag values.
LANGUAGE_TAGS = {
    "Auto": "",
    "Chinese": "",
    "English": "",
}


def download_models() -> tuple[str, str]:
    """Download the ASR model and audio tokenizer snapshots.

    Each repo is fetched with ``snapshot_download`` into
    ``DOWNLOAD_ROOT/<repo id with "/" replaced by "_">``. The ``HF_TOKEN``
    environment variable, when set, is passed as the access token.

    Returns:
        ``(model_path, tokenizer_path)``: the local snapshot directories.
    """
    os.makedirs(DOWNLOAD_ROOT, exist_ok=True)
    hf_token = os.getenv("HF_TOKEN")
    model_path = os.path.join(DOWNLOAD_ROOT, MODEL_REPO.replace("/", "_"))
    tokenizer_path = os.path.join(DOWNLOAD_ROOT, TOKENIZER_REPO.replace("/", "_"))

    print(f"[download] {MODEL_REPO} -> {model_path}")
    snapshot_download(repo_id=MODEL_REPO, token=hf_token, local_dir=model_path)
    print(f"[download] {TOKENIZER_REPO} -> {tokenizer_path}")
    snapshot_download(repo_id=TOKENIZER_REPO, token=hf_token, local_dir=tokenizer_path)
    return model_path, tokenizer_path


class ASRGenerator:
    """Thin adapter that exposes ``transcribe`` on top of a MimoAudio model."""

    def __init__(self, model):
        self.model = model

    def transcribe(self, audio_path, audio_tag=""):
        """Return ``self.model.asr_sft(audio_path, audio_tag=audio_tag)``."""
        return self.model.asr_sft(audio_path, audio_tag=audio_tag)


class MiMoV25ASRInterface:
    """Owns the loaded model and builds the Gradio UI around it."""

    def __init__(self, model_path, tokenizer_path):
        """Load ``MimoAudio`` from the given model and tokenizer directories."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): `device` is only logged; it is not passed to MimoAudio,
        # so this reports CUDA availability rather than the device the model
        # actually runs on — confirm against MimoAudio.
        print(f"[init] device={device}")
        print(f"[init] model_path={model_path}")
        print(f"[init] tokenizer_path={tokenizer_path}")

        self.model = MimoAudio(model_path, tokenizer_path)
        self.asr_generator = ASRGenerator(self.model)
        print("[init] model ready")

    def transcribe(self, uploaded_audio, recorded_audio, language_choice):
        """Transcribe the uploaded file, or the microphone recording if none.

        Args:
            uploaded_audio: Filepath from the upload widget, or ``None``.
            recorded_audio: Filepath from the microphone widget, or ``None``.
            language_choice: A key of ``LANGUAGE_TAGS``; unknown keys fall
                back to an empty tag.

        Returns:
            ``(transcript, status_message)``. On failure the transcript is an
            empty string and the status message describes the error.
        """
        # An uploaded file takes precedence over a microphone recording.
        audio_path = uploaded_audio or recorded_audio
        # Falsy check (not just `is None`) so an empty path string is rejected
        # here with a clear message instead of failing inside the model.
        if not audio_path:
            return "", "❌ Error: Please upload an audio file or record from your microphone."

        audio_tag = LANGUAGE_TAGS.get(language_choice, "")
        try:
            print("Performing ASR task:")
            print(f" Audio: {audio_path}")
            print(f" Language: {language_choice} (tag='{audio_tag}')")

            # perf_counter is monotonic, so the reported duration is not
            # skewed by wall-clock adjustments.
            start = time.perf_counter()
            transcript = self.asr_generator.transcribe(audio_path, audio_tag=audio_tag)
            elapsed = time.perf_counter() - start

            status_msg = (
                f"✅ Transcription completed in {elapsed:.2f}s\n"
                f"🎵 Input audio: {os.path.basename(audio_path)}\n"
                f"🌐 Language tag: {language_choice}"
            )
            return transcript, status_msg
        except Exception as e:
            # UI boundary: surface the failure in the status box rather than
            # failing the request, but keep the full traceback in server logs.
            error_msg = f"❌ Error during transcription: {str(e)}"
            print(error_msg)
            traceback.print_exc()
            return "", error_msg

    def create_interface(self):
        """Build the Gradio Blocks UI wired to :meth:`transcribe`.

        Returns:
            The constructed (not yet launched) ``gr.Blocks`` interface.
        """
        with gr.Blocks(
            title="MiMo-V2.5-ASR Speech Recognition",
            theme=gr.themes.Soft(),
            fill_height=True,
            analytics_enabled=False,
        ) as iface:
            gr.Markdown("# 🎙️ MiMo-V2.5-ASR: Robust Speech Recognition")
            gr.Markdown(
                "Upload an audio file **or** record directly from your microphone. "
                "Supports Chinese, English, Chinese dialects, code-switch, singing, "
                "noisy environments, and multi-speaker scenarios."
            )

            with gr.Row():
                # Left column: inputs and the transcribe action.
                with gr.Column():
                    uploaded_audio = gr.Audio(
                        label="Upload Audio File",
                        type="filepath",
                        sources=["upload"],
                        interactive=True,
                    )
                    recorded_audio = gr.Audio(
                        label="Or Record from Microphone",
                        type="filepath",
                        sources=["microphone"],
                        interactive=True,
                    )
                    language_choice = gr.Radio(
                        label="Language Tag",
                        choices=list(LANGUAGE_TAGS.keys()),
                        value="Auto",
                        info=(
                            "Auto: automatic language detection (recommended for "
                            "code-switched speech). Select Chinese or English to "
                            "bias the model toward that language."
                        ),
                    )
                    transcribe_btn = gr.Button(
                        "🎧 Transcribe", variant="primary", size="lg"
                    )

                # Right column: read-only outputs.
                with gr.Column():
                    output_text = gr.Textbox(
                        label="Transcription",
                        lines=10,
                        interactive=False,
                        placeholder="Transcription result will appear here...",
                        show_copy_button=True,
                    )
                    status = gr.Textbox(
                        label="Status",
                        lines=4,
                        interactive=False,
                        placeholder="Processing status will be shown here...",
                    )

            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear", size="sm")

            transcribe_btn.click(
                fn=self.transcribe,
                inputs=[uploaded_audio, recorded_audio, language_choice],
                outputs=[output_text, status],
            )

            def clear_all():
                # Reset both audio inputs, restore the default language choice,
                # and blank the two output boxes (order matches `outputs`).
                return None, None, "Auto", "", ""

            clear_btn.click(
                fn=clear_all,
                outputs=[
                    uploaded_audio,
                    recorded_audio,
                    language_choice,
                    output_text,
                    status,
                ],
            )

        return iface


def main():
    """Download checkpoints, build the UI, and launch the Gradio server."""
    print("🚀 Launch MiMo-V2.5-ASR demo...")
    model_path, tokenizer_path = download_models()
    interface = MiMoV25ASRInterface(model_path, tokenizer_path)
    iface = interface.create_interface()

    host = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    port = int(os.environ.get("GRADIO_SERVER_PORT", "7898"))
    print(f"🌐 Launch service - {host}:{port}")
    # NOTE(review): default_concurrency_limit=4 lets up to four transcribe
    # requests share the single MimoAudio instance at once — confirm the model
    # is safe for concurrent use. max_size caps the number of queued requests.
    iface.queue(default_concurrency_limit=4, max_size=20).launch(
        server_name=host,
        server_port=port,
        show_api=False,
    )


if __name__ == "__main__":
    main()