Spaces:
Running on A100
Running on A100
| # Copyright 2025 Xiaomi Corporation. | |
| import os | |
| import time | |
| import gradio as gr | |
| import torch | |
| from huggingface_hub import snapshot_download | |
| from src.mimo_audio.mimo_audio import MimoAudio | |
# Hugging Face repos holding the ASR model weights and the audio tokenizer.
MODEL_REPO = "XiaomiMiMo/MiMo-V2.5-ASR"
TOKENIZER_REPO = "XiaomiMiMo/MiMo-Audio-Tokenizer"

# Local directory the snapshots are downloaded into; overridable via env.
DOWNLOAD_ROOT = os.getenv("MIMO_DOWNLOAD_ROOT", "assets/models")

# UI language choice -> tag string forwarded to the model as ``audio_tag``.
# "Auto" maps to an empty tag (the UI describes it as automatic detection).
LANGUAGE_TAGS = dict(
    Auto="",
    Chinese="<chinese>",
    English="<english>",
)
def download_models():
    """Download the ASR model and audio tokenizer snapshots into DOWNLOAD_ROOT.

    Each repo is mirrored into its own sub-directory of DOWNLOAD_ROOT, named
    after the repo id with "/" replaced by "_". The ``HF_TOKEN`` environment
    variable is passed through to ``snapshot_download`` as-is (None when unset).

    Returns:
        tuple: ``(model_path, tokenizer_path)`` local directories.
    """
    os.makedirs(DOWNLOAD_ROOT, exist_ok=True)
    token = os.getenv("HF_TOKEN")

    local_dirs = []
    # Model first, then tokenizer -- same order as the returned tuple.
    for repo_id in (MODEL_REPO, TOKENIZER_REPO):
        local_dir = os.path.join(DOWNLOAD_ROOT, repo_id.replace("/", "_"))
        print(f"[download] {repo_id} -> {local_dir}")
        snapshot_download(repo_id=repo_id, token=token, local_dir=local_dir)
        local_dirs.append(local_dir)

    model_dir, tokenizer_dir = local_dirs
    return model_dir, tokenizer_dir
class ASRGenerator:
    """Adapter that exposes ``transcribe`` on top of a model's ``asr_sft``."""

    def __init__(self, model):
        # Wrapped model; only its ``asr_sft`` method is used by this class.
        self.model = model

    def transcribe(self, audio_path, audio_tag=""):
        """Run ASR on ``audio_path`` and return the result of ``asr_sft``.

        Args:
            audio_path: Audio file path, forwarded positionally.
            audio_tag: Language tag string forwarded as the ``audio_tag``
                keyword; empty by default.
        """
        run_asr = self.model.asr_sft
        return run_asr(audio_path, audio_tag=audio_tag)
class MiMoV25ASRInterface:
    """Gradio demo front-end for MiMo-V2.5-ASR.

    The model is loaded once in ``__init__``. ``transcribe`` is the click
    handler for the UI that ``create_interface`` builds.
    """

    def __init__(self, model_path, tokenizer_path):
        """Load ``MimoAudio`` from local snapshots and wrap it for ASR.

        Args:
            model_path: Local directory containing the ASR model.
            tokenizer_path: Local directory containing the audio tokenizer.
        """
        # NOTE(review): ``device`` is only logged here; it is not passed to
        # MimoAudio, which presumably selects its own device -- confirm.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[init] device={device}")
        print(f"[init] model_path={model_path}")
        print(f"[init] tokenizer_path={tokenizer_path}")
        self.model = MimoAudio(model_path, tokenizer_path)
        self.asr_generator = ASRGenerator(self.model)
        print("[init] model ready")

    def transcribe(self, uploaded_audio, recorded_audio, language_choice):
        """Transcribe the uploaded or recorded clip.

        Args:
            uploaded_audio: File path from the upload widget, or None.
            recorded_audio: File path from the microphone widget, or None.
                Ignored when an upload is present.
            language_choice: A key of ``LANGUAGE_TAGS``; unknown keys fall
                back to the empty tag.

        Returns:
            ``(transcript, status_message)``. On any failure the transcript
            is ``""`` and the status message describes the error.
        """
        # The upload takes precedence when both inputs are populated.
        audio_path = uploaded_audio or recorded_audio
        # A falsy check also rejects an empty-string path, not just None.
        if not audio_path:
            return "", "β Error: Please upload an audio file or record from your microphone."

        audio_tag = LANGUAGE_TAGS.get(language_choice, "")
        try:
            print("Performing ASR task:")
            print(f" Audio: {audio_path}")
            print(f" Language: {language_choice} (tag='{audio_tag}')")
            # perf_counter is monotonic, so the reported duration cannot be
            # skewed by wall-clock adjustments the way time.time() can.
            start = time.perf_counter()
            transcript = self.asr_generator.transcribe(audio_path, audio_tag=audio_tag)
            elapsed = time.perf_counter() - start
            status_msg = (
                f"β Transcription completed in {elapsed:.2f}s\n"
                f"π΅ Input audio: {os.path.basename(audio_path)}\n"
                f"π Language tag: {language_choice}"
            )
            return transcript, status_msg
        except Exception as e:
            # UI boundary: surface the failure in the status box instead of
            # crashing the handler, but keep the full traceback in the logs.
            import traceback

            error_msg = f"β Error during transcription: {e}"
            print(error_msg)
            traceback.print_exc()
            return "", error_msg

    def create_interface(self):
        """Build the Gradio Blocks UI and wire up its event handlers.

        Returns:
            The ``gr.Blocks`` app (not yet queued or launched).
        """
        # Shared by the radio's initial value and the Clear handler's reset.
        default_language = "Auto"

        with gr.Blocks(
            title="MiMo-V2.5-ASR Speech Recognition",
            theme=gr.themes.Soft(),
            fill_height=True,
            analytics_enabled=False,
        ) as iface:
            gr.Markdown("# ποΈ MiMo-V2.5-ASR: Robust Speech Recognition")
            gr.Markdown(
                "Upload an audio file **or** record directly from your microphone. "
                "Supports Chinese, English, Chinese dialects, code-switch, singing, "
                "noisy environments, and multi-speaker scenarios."
            )

            with gr.Row():
                # Left column: audio inputs, language selector, action button.
                with gr.Column():
                    uploaded_audio = gr.Audio(
                        label="Upload Audio File",
                        type="filepath",
                        sources=["upload"],
                        interactive=True,
                    )
                    recorded_audio = gr.Audio(
                        label="Or Record from Microphone",
                        type="filepath",
                        sources=["microphone"],
                        interactive=True,
                    )
                    language_choice = gr.Radio(
                        label="Language Tag",
                        choices=list(LANGUAGE_TAGS.keys()),
                        value=default_language,
                        info=(
                            "Auto: automatic language detection (recommended for "
                            "code-switched speech). Select Chinese or English to "
                            "bias the model toward that language."
                        ),
                    )
                    transcribe_btn = gr.Button(
                        "π§ Transcribe", variant="primary", size="lg"
                    )

                # Right column: transcription output, status, and reset.
                with gr.Column():
                    output_text = gr.Textbox(
                        label="Transcription",
                        lines=10,
                        interactive=False,
                        placeholder="Transcription result will appear here...",
                        show_copy_button=True,
                    )
                    status = gr.Textbox(
                        label="Status",
                        lines=4,
                        interactive=False,
                        placeholder="Processing status will be shown here...",
                    )
                    with gr.Row():
                        clear_btn = gr.Button("ποΈ Clear", size="sm")

            transcribe_btn.click(
                fn=self.transcribe,
                inputs=[uploaded_audio, recorded_audio, language_choice],
                outputs=[output_text, status],
            )

            def clear_all():
                # One value per output below, in the same order.
                return None, None, default_language, "", ""

            clear_btn.click(
                fn=clear_all,
                outputs=[
                    uploaded_audio,
                    recorded_audio,
                    language_choice,
                    output_text,
                    status,
                ],
            )

        return iface
def main():
    """Download the weights, build the Gradio app, and serve it.

    The bind address and port are read from ``GRADIO_SERVER_NAME``
    (default "0.0.0.0") and ``GRADIO_SERVER_PORT`` (default "7898").
    """
    print("π Launch MiMo-V2.5-ASR demo...")
    model_dir, tokenizer_dir = download_models()
    app = MiMoV25ASRInterface(model_dir, tokenizer_dir)
    demo = app.create_interface()

    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7898"))
    print(f"π Launch service - {server_name}:{server_port}")

    # At most 4 requests run concurrently; up to 20 more may wait in the queue.
    queued = demo.queue(default_concurrency_limit=4, max_size=20)
    queued.launch(
        server_name=server_name,
        server_port=server_port,
        show_api=False,
    )


if __name__ == "__main__":
    main()