# Copyright 2025 Xiaomi Corporation.
import random
import time
from typing import Union

import torch
import torchaudio
from torchaudio.transforms import MelSpectrogram
from transformers import (
    AutoTokenizer,
    GenerationConfig,
)
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from .process_speechdata import InputSegment
from ..mimo_audio_tokenizer import MiMoAudioTokenizer
from .templates import asr_en_templates, asr_zh_templates
from .modeling_mimo_audio import (
    MiMoAudioArguments,
    MiMoAudioForCausalLM,
    MiMoSampler,
    MiMoStopper,
)


class MimoAudio:
    def __init__(
        self,
        model_path: str,
        mimo_audio_tokenizer_path: str,
        device: str | None = None,
    ) -> None:
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.path = model_path
        self.mimo_audio_tokenizer_path = mimo_audio_tokenizer_path

        self.tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
            self.path
        )
        self.padding_idx = int(self.tokenizer.pad_token_id)

        # Make sure the control tokens used by the speech LM exist in the vocab.
        special_tokens = [
            "<|sosp|>",
            "<|eosp|>",
            "<|empty|>",
            "<|Human|>",
            "<|SpeechLM|>",
            "<|sostm|>",
            "<|eostm|>",
            "<|eot|>",
        ]
        for token in special_tokens:
            if token not in self.tokenizer.get_vocab():
                print(f"Adding special token {token} to tokenizer vocab")
                self.tokenizer.add_tokens([token], special_tokens=True)

        self.sosp_idx = self.tokenizer.convert_tokens_to_ids("<|sosp|>")
        self.eosp_idx = self.tokenizer.convert_tokens_to_ids("<|eosp|>")
        self.empty_token = self.tokenizer.convert_tokens_to_ids("<|empty|>")
        self.sostm_idx = self.tokenizer.convert_tokens_to_ids("<|sostm|>")
        self.eostm_idx = self.tokenizer.convert_tokens_to_ids("<|eostm|>")
        self.eot_idx = self.tokenizer.convert_tokens_to_ids("<|eot|>")
        self.im_start_idx = self.tokenizer.convert_tokens_to_ids("<|im_start|>")
        self.im_end_idx = self.tokenizer.convert_tokens_to_ids("<|im_end|>")

        model_args = MiMoAudioArguments(
            model_name_or_path=self.path,
            sosp_idx=self.sosp_idx,
            eosp_idx=self.eosp_idx,
            empty_idx=self.empty_token,
            sostm_idx=self.sostm_idx,
            eostm_idx=self.eostm_idx,
            eot_idx=self.eot_idx,
        )

        start_loading_time = time.monotonic()
        self.model = MiMoAudioForCausalLM.from_pretrained(
            self.path,
            args=model_args,
            torch_dtype=torch.bfloat16,
            device_map={"": self.device},
        )
        self.group_size = self.model.config.group_size
        self.audio_channels = self.model.config.audio_channels
        self.delay_pattern = self.model.config.delay_pattern
        self.vocab_size = self.model.config.vocab_size
        self.speech_zeroemb_idx = self.model.speech_empty_ids
        self.model.eval()
        print(
            f"Model loaded in {time.monotonic() - start_loading_time:.2f} seconds, device: {self.device}"
        )

        self.generate_kwargs = {
            "max_length": 8192,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        # Default sampling settings for the global and local decoding heads.
        self.default_global_sampler = MiMoSampler(
            do_sample=True, temperature=0.6, top_k=50, top_p=0.95
        )
        self.default_local_sampler = MiMoSampler(
            do_sample=True, temperature=0.9, top_k=50, top_p=0.95
        )
        # Per-task overrides; ASR decodes the global stream greedily.
        self.task_sampler_configs = {
            "asr": {
                "global": MiMoSampler(do_sample=False, temperature=1.0, top_p=1.0),
                "local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95),
            },
        }

        start_loading_mimo_audio_tokenizer_time = time.monotonic()
        self.mimo_audio_tokenizer = MiMoAudioTokenizer.from_pretrained(
            self.mimo_audio_tokenizer_path,
            torch_dtype=torch.bfloat16,
        )
        self.mimo_audio_tokenizer.eval().to(self.device)
        print(
            f"MiMo-Audio Tokenizer loaded in {time.monotonic() - start_loading_mimo_audio_tokenizer_time:.2f} seconds, device: {self.device}"
        )

        # Initialize mel spectrogram transform for consistent processing
        self.mel_transform = MelSpectrogram(
            sample_rate=self.mimo_audio_tokenizer.config.sampling_rate,
            n_fft=self.mimo_audio_tokenizer.config.nfft,
            hop_length=self.mimo_audio_tokenizer.config.hop_length,
            win_length=self.mimo_audio_tokenizer.config.window_size,
            f_min=self.mimo_audio_tokenizer.config.fmin,
            f_max=self.mimo_audio_tokenizer.config.fmax,
            n_mels=self.mimo_audio_tokenizer.config.n_mels,
            power=1.0,
            center=True,
        ).to(self.device)
    def get_task_sampler(self, task_name):
        if task_name not in self.task_sampler_configs:
            return {
                "global": self.default_global_sampler,
                "local": self.default_local_sampler,
            }
        return self.task_sampler_configs[task_name]
    def wav2mel(self, wav):
        # Log-mel spectrogram; the clip floor avoids log(0) on silent frames.
        spec = self.mel_transform(wav[None, :])
        return torch.log(torch.clip(spec, min=1e-7)).squeeze()
    def resample_audio_if_needed(self, wav_tensor: torch.Tensor, original_sr: int):
        target_sr = self.mimo_audio_tokenizer.config.sampling_rate
        if original_sr != target_sr:
            wav_tensor = torchaudio.functional.resample(
                wav_tensor, original_sr, target_sr
            )
        return wav_tensor
    def group_by_length(self, features: torch.Tensor, lengths: torch.Tensor, max_length: int):
        if features.size(0) != lengths.sum().item():
            raise ValueError(
                f"Feature size mismatch: {features.size(0)} vs {lengths.sum().item()}"
            )
        # Greedily pack consecutive sequences into groups whose total length
        # stays within max_length; see the worked example after this method.
        split_points = []
        current_sum = 0
        for i, seq_len in enumerate(lengths):
            if current_sum + seq_len > max_length and current_sum > 0:
                split_points.append(i)
                current_sum = seq_len.item()
            else:
                current_sum += seq_len.item()
        # Convert split points to group sizes
        group_sizes = []
        prev = 0
        for point in split_points:
            group_sizes.append(point - prev)
            prev = point
        if prev < len(lengths):
            group_sizes.append(len(lengths) - prev)
        len_groups = torch.split(lengths, group_sizes)
        feature_sizes = [group.sum().item() for group in len_groups]
        feature_groups = torch.split(features, feature_sizes)
        return feature_groups, len_groups
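    # Illustrative example (hypothetical numbers): lengths = [100, 200, 150]
    # with max_length = 300 packs the first two sequences into one group
    # (100 + 200 = 300 still fits) and starts a new group at the third, so
    # the features split into chunks of 300 and 150 rows.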
    def encode_batch(self, input_features: torch.Tensor, input_lens: torch.Tensor, max_length: int = 256000):
        # Encode in bounded-size groups to keep peak memory in check.
        feature_groups, len_groups = self.group_by_length(input_features, input_lens, max_length)
        encoded_parts = []
        for features, lengths in zip(feature_groups, len_groups):
            with torch.no_grad():
                codes, _ = self.mimo_audio_tokenizer.encoder.encode(
                    input_features=features.to(self.device),
                    input_lens=lengths.to(self.device),
                    return_codes_only=True,
                )
                encoded_parts.append(codes)
        return torch.cat(encoded_parts, dim=-1)
    def preprocess_input(
        self,
        input: Union[str, torch.Tensor],
    ):
        if isinstance(input, torch.Tensor):
            # A raw waveform tensor is assumed to already be mono and sampled
            # at the tokenizer's sampling rate.
            wav = input
        else:
            wav, sr = torchaudio.load(input)
            if wav.ndim == 2:
                wav = wav.mean(dim=0)  # downmix to mono
            wav = self.resample_audio_if_needed(wav, sr)
        wav = wav.to(self.device)

        # Split waveform into 30s chunks, tokenize each separately, then concatenate codes
        target_sr = self.mimo_audio_tokenizer.config.sampling_rate
        chunk_samples = 30 * target_sr
        n_fft = self.mimo_audio_tokenizer.config.nfft
        total_samples = wav.shape[-1]
        code_parts = []
        start = 0
        while start < total_samples:
            end = min(start + chunk_samples, total_samples)
            # Merge a too-short trailing chunk (would break mel reflect padding)
            # into the current one.
            if 0 < total_samples - end < n_fft:
                end = total_samples
            chunk = wav[start:end]
            # Zero-pad if the entire audio is shorter than n_fft.
            if chunk.shape[-1] < n_fft:
                chunk = torch.nn.functional.pad(chunk, (0, n_fft - chunk.shape[-1]))
            mel = self.wav2mel(chunk).transpose(0, 1)  # (seq_len, n_mels)
            codes_chunk = self.encode_batch(
                input_features=mel,
                input_lens=torch.tensor([mel.size(0)]),
            )
            code_parts.append(codes_chunk)
            start = end
        codes_packed = torch.cat(code_parts, dim=-1)

        codes = codes_packed.transpose(0, 1).detach().cpu()
        audio_codes = codes[:, :self.audio_channels]
        # Pad the sequence to be a multiple of group_size by repeating the last frame
        num_timesteps = audio_codes.shape[0]
        if num_timesteps % self.group_size != 0:
            padding_needed = self.group_size - (num_timesteps % self.group_size)
            last_tokens = audio_codes[-1:, :]  # Keep dim for repeat
            padding_tokens = last_tokens.repeat(padding_needed, 1)
            audio_codes = torch.cat([audio_codes, padding_tokens], dim=0)
        audio_tokenized = audio_codes.reshape(-1)
        return audio_tokenized
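    # Shape sketch (hypothetical sizes): with audio_channels = 4 and
    # group_size = 4, a 10-frame code matrix of shape (10, 4) is padded to
    # (12, 4) by repeating the last frame, then flattened to 48 tokens
    # ordered frame by frame.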
    def get_input_ids(self, prompt):
        input_ids = [
            seg.to_input_id(
                self.tokenizer,
                self.group_size,
                self.audio_channels,
            )
            for seg in prompt
        ]
        input_ids = torch.cat(input_ids, dim=1)
        return input_ids.to(self.device)
    def get_asr_sft_prompt(
        self,
        input: Union[str, torch.Tensor, None] = None,
        audio_tag="",
    ):
        audio_tokenized = self.preprocess_input(input)
        # Pick a transcription instruction matching the tagged language;
        # without a tag, draw from both template pools.
        if "<chinese>" in audio_tag:
            template = random.choice(asr_zh_templates)
        elif "<english>" in audio_tag:
            template = random.choice(asr_en_templates)
        else:
            template = random.choice(asr_zh_templates + asr_en_templates)
        lm_prompt = [
            InputSegment(
                text="<|im_start|>user\n",
                speech_zeroemb_idx=self.speech_zeroemb_idx,
                text_zeroemb_idx=self.empty_token,
            ),
            InputSegment(
                audio=audio_tokenized,
                speech_zeroemb_idx=self.speech_zeroemb_idx,
                text_zeroemb_idx=self.empty_token,
            ),
            InputSegment(
                text=template,
                speech_zeroemb_idx=self.speech_zeroemb_idx,
                text_zeroemb_idx=self.empty_token,
            ),
            InputSegment(
                text="<|im_end|>\n",
                speech_zeroemb_idx=self.speech_zeroemb_idx,
                text_zeroemb_idx=self.empty_token,
            ),
            InputSegment(
                text="<|im_start|>assistant\n",
                speech_zeroemb_idx=self.speech_zeroemb_idx,
                text_zeroemb_idx=self.empty_token,
            ),
            InputSegment(
                text=f"<think>\n\n</think>\n{audio_tag}",
                speech_zeroemb_idx=self.speech_zeroemb_idx,
                text_zeroemb_idx=self.empty_token,
            ),
        ]
        input_ids = self.get_input_ids(lm_prompt)
        return input_ids
    def forward(
        self,
        input_ids,
        stopping_criteria=None,
        min_new_tokens=0,
        max_new_tokens=8192,
        task_name=None,
    ):
        task_sampler = self.get_task_sampler(task_name)
        generation_kwargs = self.generate_kwargs.copy()
        generation_config = GenerationConfig(**generation_kwargs)
        # Interleave channels: each timestep contributes one text token
        # followed by `audio_channels` audio tokens.
        input_ids = input_ids.T.reshape(1, -1)  # [B, flattened(T, audio_channels + 1)]
        prompt_length = input_ids.shape[1] // (self.audio_channels + 1)
        # Length limits are counted in groups of `group_size` timesteps.
        max_length = prompt_length // self.group_size + max_new_tokens
        min_length = prompt_length // self.group_size + min_new_tokens
        if stopping_criteria is not None:
            for criterion in stopping_criteria:
                if isinstance(criterion, MiMoStopper):
                    criterion.max_length = max_length
                    criterion.min_length = min_length
        generated_ids = self.model.generate(
            input_ids,
            generation_config,
            stopping_criteria=stopping_criteria,
            global_sampler=task_sampler["global"],
            local_sampler=task_sampler["local"],
        )
        # Undo the interleaving and drop the prompt; channel 0 is the text
        # stream, which advances one token per group of `group_size` steps.
        generated_ids = (
            generated_ids.int()
            .cpu()
            .reshape(-1, self.audio_channels + 1)
            .T[:, prompt_length:]
        )
        text = generated_ids[0, :: self.group_size][:-1]
        detokenized_text = (
            self.tokenizer.decode(text, skip_special_tokens=False)
            .strip()
            .replace("<|empty|>", "")
            .replace("<|eot|>", "")
            .replace("<|eostm|>", "")
        )
        print("Text channel:\t", detokenized_text)
        return detokenized_text
    def asr_sft(self, audio, audio_tag=""):
        stopping_criteria = [
            MiMoStopper(
                stop_tokens=[self.tokenizer.eos_token_id, self.im_end_idx],
                group_size=self.group_size,
                audio_channels=self.audio_channels,
            )
        ]
        input_ids = self.get_asr_sft_prompt(audio, audio_tag)
        result = self.forward(input_ids, stopping_criteria=stopping_criteria, task_name="asr")
        if "<chinese>" in result or "<english>" in result:
            result = result.replace("<chinese>", "").replace("<english>", "").strip()
        return result
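

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint paths and audio file below are
    # placeholders (hypothetical), not shipped defaults; point them at local
    # copies of the MiMo-Audio model and tokenizer and a real recording.
    model = MimoAudio(
        model_path="path/to/MiMo-Audio",
        mimo_audio_tokenizer_path="path/to/MiMo-Audio-Tokenizer",
    )
    # "<english>" selects an English ASR prompt template; pass "<chinese>"
    # for Chinese, or leave the tag empty to sample from both template pools.
    transcript = model.asr_sft("sample.wav", audio_tag="<english>")
    print(transcript)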