MarkDaniel212 committed on
Commit
2c4c098
·
verified ·
1 Parent(s): 51637ef

Initial Docker-based ASR demo (app.py + src + requirements)

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ assets/models/
2
+ .venv/
3
+ __pycache__/
4
+ *.pyc
Dockerfile ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
2
+
3
+ ENV DEBIAN_FRONTEND=noninteractive
4
+
5
+ RUN apt-get update && apt-get install -y \
6
+ wget \
7
+ build-essential \
8
+ libssl-dev \
9
+ zlib1g-dev \
10
+ libbz2-dev \
11
+ libreadline-dev \
12
+ libsqlite3-dev \
13
+ libffi-dev \
14
+ libncursesw5-dev \
15
+ xz-utils \
16
+ tk-dev \
17
+ liblzma-dev \
18
+ git \
19
+ ca-certificates \
20
+ curl \
21
+ ffmpeg \
22
+ && rm -rf /var/lib/apt/lists/*
23
+
24
+ WORKDIR /tmp
25
+ RUN wget https://www.python.org/ftp/python/3.12.7/Python-3.12.7.tgz \
26
+ && tar -xvf Python-3.12.7.tgz \
27
+ && cd Python-3.12.7 \
28
+ && ./configure --enable-optimizations --prefix=/usr/local \
29
+ && make -j$(nproc) \
30
+ && make altinstall \
31
+ && cd .. \
32
+ && rm -rf Python-3.12.7 Python-3.12.7.tgz
33
+
34
+ RUN python3.12 -m ensurepip --upgrade
35
+ RUN python3.12 -m pip install --upgrade pip
36
+
37
+ RUN useradd -m -u 1000 user
38
+ USER user
39
+ ENV HOME=/home/user
40
+ ENV PATH="$HOME/.local/bin:$PATH"
41
+
42
+ WORKDIR /app
43
+
44
+ COPY --chown=user ./requirements.txt requirements.txt
45
+ RUN python3.12 -m pip install --no-cache-dir --upgrade -r requirements.txt
46
+
47
+ COPY --chown=user . /app
48
+
49
+ CMD ["python3.12", "app.py"]
README.md CHANGED
@@ -3,13 +3,13 @@ title: MiMo V2.5 ASR
3
  emoji: 🦀
4
  colorFrom: blue
5
  colorTo: green
6
- sdk: gradio
7
- sdk_version: 6.13.0
8
- python_version: '3.12'
9
- app_file: app.py
10
  pinned: false
11
  license: apache-2.0
12
  short_description: Leading ASR models from Xiaomi MiMo
13
  ---
14
 
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
3
  emoji: 🦀
4
  colorFrom: blue
5
  colorTo: green
6
+ sdk: docker
7
+ app_port: 7898
 
 
8
  pinned: false
9
  license: apache-2.0
10
  short_description: Leading ASR models from Xiaomi MiMo
11
  ---
12
 
13
+ MiMo-V2.5-ASR: Robust Speech Recognition across languages, dialects, and complex acoustic scenarios.
14
+
15
+ Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>.
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ import os
3
+ import time
4
+
5
+ import gradio as gr
6
+ import torch
7
+ from huggingface_hub import snapshot_download
8
+
9
+ from src.mimo_audio.mimo_audio import MimoAudio
10
+
11
+
12
# Hugging Face Hub repositories that hold the ASR model and its audio tokenizer.
MODEL_REPO = "XiaomiMiMo/MiMo-V2.5-ASR"
TOKENIZER_REPO = "XiaomiMiMo/MiMo-Audio-Tokenizer"

# Local directory for downloaded snapshots; override with MIMO_DOWNLOAD_ROOT.
DOWNLOAD_ROOT = os.getenv("MIMO_DOWNLOAD_ROOT", "assets/models")

# UI language choice -> tag string handed to the model ("" means no tag).
LANGUAGE_TAGS = dict(
    Auto="",
    Chinese="<chinese>",
    English="<english>",
)
21
+
22
+
23
def download_models():
    """Download the model and tokenizer snapshots under ``DOWNLOAD_ROOT``.

    Each repository is stored in a sub-directory named after the repo id with
    ``/`` replaced by ``_``. The optional ``HF_TOKEN`` environment variable is
    forwarded to ``snapshot_download``.

    Returns:
        A ``(model_path, tokenizer_path)`` tuple of local directories.
    """
    os.makedirs(DOWNLOAD_ROOT, exist_ok=True)
    token = os.getenv("HF_TOKEN")

    local_dirs = []
    # Model first, tokenizer second -- the return order relies on this.
    for repo in (MODEL_REPO, TOKENIZER_REPO):
        target = os.path.join(DOWNLOAD_ROOT, repo.replace("/", "_"))
        print(f"[download] {repo} -> {target}")
        snapshot_download(repo_id=repo, token=token, local_dir=target)
        local_dirs.append(target)

    model_path, tokenizer_path = local_dirs
    return model_path, tokenizer_path
37
+
38
+
39
class ASRGenerator:
    """Thin adapter that exposes a model's ``asr_sft`` call as ``transcribe``."""

    def __init__(self, model):
        # Any object providing ``asr_sft(audio_path, audio_tag=...)``.
        self.model = model

    def transcribe(self, audio_path, audio_tag=""):
        """Return whatever ``model.asr_sft`` produces for ``audio_path``."""
        run_asr = self.model.asr_sft
        return run_asr(audio_path, audio_tag=audio_tag)
45
+
46
+
47
class MiMoV25ASRInterface:
    """Gradio front end for the MiMo-V2.5-ASR model.

    Loads the model once at construction time and exposes a ``transcribe``
    callback plus a ``create_interface`` builder for the Blocks UI.
    """

    def __init__(self, model_path, tokenizer_path):
        """Load the model and tokenizer from the given local directories."""
        # Only used for the log line below; MimoAudio picks its own device.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[init] device={device}")
        print(f"[init] model_path={model_path}")
        print(f"[init] tokenizer_path={tokenizer_path}")

        self.model = MimoAudio(model_path, tokenizer_path)
        self.asr_generator = ASRGenerator(self.model)
        print("[init] model ready")

    def transcribe(self, uploaded_audio, recorded_audio, language_choice):
        """Transcribe the uploaded file, falling back to the recording.

        Args:
            uploaded_audio: File path from the upload widget, or a falsy value.
            recorded_audio: File path from the microphone widget, or a falsy
                value.
            language_choice: A key of ``LANGUAGE_TAGS``; unknown keys map to
                the empty tag.

        Returns:
            A ``(transcript, status_message)`` tuple. On failure the
            transcript is ``""`` and the status message describes the error.
        """
        audio_path = uploaded_audio or recorded_audio
        # Treat every falsy value (None or an empty path) as "no audio" so an
        # empty string never reaches the model.
        if not audio_path:
            return "", "❌ Error: Please upload an audio file or record from your microphone."

        audio_tag = LANGUAGE_TAGS.get(language_choice, "")

        # UI boundary: any failure is reported in the status box instead of
        # crashing the request.
        try:
            print("Performing ASR task:")
            print(f" Audio: {audio_path}")
            print(f" Language: {language_choice} (tag='{audio_tag}')")

            start = time.time()
            transcript = self.asr_generator.transcribe(audio_path, audio_tag=audio_tag)
            elapsed = time.time() - start

            status_msg = (
                f"✅ Transcription completed in {elapsed:.2f}s\n"
                f"🎵 Input audio: {os.path.basename(audio_path)}\n"
                f"🌐 Language tag: {language_choice}"
            )
            return transcript, status_msg

        except Exception as e:
            error_msg = f"❌ Error during transcription: {str(e)}"
            print(error_msg)
            return "", error_msg

    def create_interface(self):
        """Build and return the Gradio Blocks UI wired to ``transcribe``."""
        with gr.Blocks(title="MiMo-V2.5-ASR Speech Recognition", theme=gr.themes.Soft()) as iface:
            gr.Markdown("# 🎙️ MiMo-V2.5-ASR: Robust Speech Recognition")
            gr.Markdown(
                "Upload an audio file **or** record directly from your microphone. "
                "Supports Chinese, English, Chinese dialects, code-switch, singing, "
                "noisy environments, and multi-speaker scenarios."
            )

            with gr.Row():
                # Left column: inputs.
                with gr.Column():
                    uploaded_audio = gr.Audio(
                        label="Upload Audio File",
                        type="filepath",
                        sources=["upload"],
                        interactive=True,
                    )
                    recorded_audio = gr.Audio(
                        label="Or Record from Microphone",
                        type="filepath",
                        sources=["microphone"],
                        interactive=True,
                    )
                    language_choice = gr.Radio(
                        label="Language Tag",
                        choices=list(LANGUAGE_TAGS.keys()),
                        value="Auto",
                        info=(
                            "Auto: automatic language detection (recommended for "
                            "code-switched speech). Select Chinese or English to "
                            "bias the model toward that language."
                        ),
                    )
                    transcribe_btn = gr.Button(
                        "🎧 Transcribe", variant="primary", size="lg"
                    )

                # Right column: outputs.
                with gr.Column():
                    output_text = gr.Textbox(
                        label="Transcription",
                        lines=10,
                        interactive=False,
                        placeholder="Transcription result will appear here...",
                        show_copy_button=True,
                    )
                    status = gr.Textbox(
                        label="Status",
                        lines=4,
                        interactive=False,
                        placeholder="Processing status will be shown here...",
                    )
                    with gr.Row():
                        clear_btn = gr.Button("🗑️ Clear", size="sm")

            transcribe_btn.click(
                fn=self.transcribe,
                inputs=[uploaded_audio, recorded_audio, language_choice],
                outputs=[output_text, status],
            )

            def clear_all():
                # Values follow the order of the ``outputs`` list below.
                return None, None, "Auto", "", ""

            clear_btn.click(
                fn=clear_all,
                outputs=[
                    uploaded_audio,
                    recorded_audio,
                    language_choice,
                    output_text,
                    status,
                ],
            )

        return iface
161
+
162
+
163
def main():
    """Download the models, build the UI and start the Gradio server.

    The bind address and port come from ``GRADIO_SERVER_NAME`` and
    ``GRADIO_SERVER_PORT``, defaulting to ``0.0.0.0`` and ``7898``.
    """
    print("🚀 Launch MiMo-V2.5-ASR demo...")

    model_dir, tokenizer_dir = download_models()
    demo = MiMoV25ASRInterface(model_dir, tokenizer_dir).create_interface()

    env = os.environ
    host = env.get("GRADIO_SERVER_NAME", "0.0.0.0")
    port = int(env.get("GRADIO_SERVER_PORT", "7898"))
    print(f"🌐 Launch service - {host}:{port}")

    queued = demo.queue()
    queued.launch(server_name=host, server_port=port)
178
+
179
+
180
+ if __name__ == "__main__":
181
+ main()
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.9.0
2
+ librosa==0.11.0
3
+ scipy==1.16.1
4
+ torch @ https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl
5
+ torchaudio @ https://download.pytorch.org/whl/cu126/torchaudio-2.6.0%2Bcu126-cp312-cp312-linux_x86_64.whl
6
+ flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.2/flash_attn-2.8.2%2Bcu12torch2.6cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
7
+ transformers==4.49.0
8
+ triton==3.2.0
9
+ gradio==5.46.1
10
+ zhon==2.1.1
11
+ huggingface_hub>=0.26.0
src/mimo_audio/mimo_audio.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ import time
3
+ import random
4
+ import torch
5
+ import torchaudio
6
+
7
+ from typing import Union
8
+ from torchaudio.transforms import MelSpectrogram
9
+ from transformers import (
10
+ AutoTokenizer,
11
+ GenerationConfig
12
+ )
13
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
14
+
15
+ from .process_speechdata import InputSegment
16
+ from ..mimo_audio_tokenizer import MiMoAudioTokenizer
17
+ from .templates import asr_en_templates, asr_zh_templates
18
+ from .modeling_mimo_audio import (
19
+ MiMoAudioArguments,
20
+ MiMoAudioForCausalLM,
21
+ MiMoSampler,
22
+ MiMoStopper,
23
+ )
24
+
25
+
26
+ class MimoAudio:
27
+
28
+ def __init__(
29
+ self,
30
+ model_path: str,
31
+ mimo_audio_tokenizer_path: str,
32
+ device: str | None = None,
33
+ ) -> None:
34
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
35
+
36
+ self.path = model_path
37
+ self.mimo_audio_tokenizer_path = mimo_audio_tokenizer_path
38
+
39
+ self.tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
40
+ self.path
41
+ )
42
+ self.padding_idx = int(self.tokenizer.pad_token_id)
43
+
44
+ special_tokens = [
45
+ "<|sosp|>",
46
+ "<|eosp|>",
47
+ "<|empty|>",
48
+ "<|Human|>",
49
+ "<|SpeechLM|>",
50
+ "<|sostm|>",
51
+ "<|eostm|>",
52
+ "<|eot|>",
53
+ ]
54
+ for token in special_tokens:
55
+ if token not in self.tokenizer.get_vocab():
56
+ print(f"Add special tokens {token} to tokenizer.vocab")
57
+ self.tokenizer.add_tokens([token], special_tokens=True)
58
+
59
+ self.sosp_idx = self.tokenizer.convert_tokens_to_ids("<|sosp|>")
60
+ self.eosp_idx = self.tokenizer.convert_tokens_to_ids("<|eosp|>")
61
+ self.empty_token = self.tokenizer.convert_tokens_to_ids("<|empty|>")
62
+ self.sostm_idx = self.tokenizer.convert_tokens_to_ids("<|sostm|>")
63
+ self.eostm_idx = self.tokenizer.convert_tokens_to_ids("<|eostm|>")
64
+ self.eot_idx = self.tokenizer.convert_tokens_to_ids("<|eot|>")
65
+ self.im_start_idx = self.tokenizer.convert_tokens_to_ids("<|im_start|>")
66
+ self.im_end_idx = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
67
+
68
+ model_args = MiMoAudioArguments(
69
+ model_name_or_path=self.path,
70
+ sosp_idx=self.sosp_idx,
71
+ eosp_idx=self.eosp_idx,
72
+ empty_idx=self.empty_token,
73
+ sostm_idx=self.sostm_idx,
74
+ eostm_idx=self.eostm_idx,
75
+ eot_idx=self.eot_idx,
76
+ )
77
+
78
+ start_loading_time = time.monotonic()
79
+ self.model = MiMoAudioForCausalLM.from_pretrained(
80
+ self.path,
81
+ args=model_args,
82
+ torch_dtype=torch.bfloat16,
83
+ device_map={"": self.device},
84
+ )
85
+
86
+ self.group_size=self.model.config.group_size
87
+ self.audio_channels=self.model.config.audio_channels
88
+ self.delay_pattern = self.model.config.delay_pattern
89
+ self.vocab_size = self.model.config.vocab_size
90
+
91
+ self.speech_zeroemb_idx = self.model.speech_empty_ids
92
+
93
+ self.model.eval()
94
+ print(
95
+ f"Model loaded in {time.monotonic() - start_loading_time:.2f} seconds, device: {self.device}"
96
+ )
97
+
98
+ self.generate_kwargs = {
99
+ "max_length": 8192,
100
+ "eos_token_id": self.tokenizer.eos_token_id,
101
+ "pad_token_id": self.tokenizer.pad_token_id,
102
+ }
103
+ self.default_global_sampler = MiMoSampler(
104
+ do_sample=True, temperature=0.6, top_k=50, top_p=0.95
105
+ )
106
+ self.default_local_sampler = MiMoSampler(
107
+ do_sample=True, temperature=0.9, top_k=50, top_p=0.95
108
+ )
109
+
110
+ self.task_sampler_configs = {
111
+ "asr": {
112
+ "global": MiMoSampler(do_sample=False, temperature=1.0, top_p=1.0),
113
+ "local": MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)
114
+ },
115
+ }
116
+
117
+ start_loading_mimo_audio_tokenizer_time = time.monotonic()
118
+ self.mimo_audio_tokenizer = MiMoAudioTokenizer.from_pretrained(self.mimo_audio_tokenizer_path)
119
+
120
+ self.mimo_audio_tokenizer.eval().bfloat16().to(self.device)
121
+ print(
122
+ f"MiMo-Audio Tokenizer loaded in {time.monotonic() - start_loading_mimo_audio_tokenizer_time:.2f} seconds, device: {self.device}"
123
+ )
124
+
125
+ # Initialize mel spectrogram transform for consistent processing
126
+ self.mel_transform = MelSpectrogram(
127
+ sample_rate=self.mimo_audio_tokenizer.config.sampling_rate,
128
+ n_fft=self.mimo_audio_tokenizer.config.nfft,
129
+ hop_length=self.mimo_audio_tokenizer.config.hop_length,
130
+ win_length=self.mimo_audio_tokenizer.config.window_size,
131
+ f_min=self.mimo_audio_tokenizer.config.fmin,
132
+ f_max=self.mimo_audio_tokenizer.config.fmax,
133
+ n_mels=self.mimo_audio_tokenizer.config.n_mels,
134
+ power=1.0,
135
+ center=True,
136
+ ).to(self.device)
137
+
138
+ def get_task_sampler(self, task_name):
139
+ if task_name not in self.task_sampler_configs:
140
+ return {
141
+ "global": self.default_global_sampler,
142
+ "local": self.default_local_sampler
143
+ }
144
+ return self.task_sampler_configs[task_name]
145
+
146
+ def wav2mel(self, wav):
147
+ spec = self.mel_transform(wav[None, :])
148
+ return torch.log(torch.clip(spec, min=1e-7)).squeeze()
149
+
150
+ def resample_audio_if_needed(self, wav_tensor: torch.Tensor, original_sr: int):
151
+ target_sr = self.mimo_audio_tokenizer.config.sampling_rate
152
+ if original_sr != target_sr:
153
+ wav_tensor = torchaudio.functional.resample(
154
+ wav_tensor, original_sr, target_sr
155
+ )
156
+ return wav_tensor
157
+
158
+ def group_by_length(self, features: torch.Tensor, lengths: torch.Tensor, max_length: int):
159
+ if features.size(0) != lengths.sum().item():
160
+ raise ValueError(f"Feature size mismatch: {features.size(0)} vs {lengths.sum().item()}")
161
+
162
+ split_points = []
163
+ current_sum = 0
164
+
165
+ for i, seq_len in enumerate(lengths):
166
+ if current_sum + seq_len > max_length and current_sum > 0:
167
+ split_points.append(i)
168
+ current_sum = seq_len.item()
169
+ else:
170
+ current_sum += seq_len.item()
171
+
172
+ # Convert split points to group sizes
173
+ group_sizes = []
174
+ prev = 0
175
+ for point in split_points:
176
+ group_sizes.append(point - prev)
177
+ prev = point
178
+ if prev < len(lengths):
179
+ group_sizes.append(len(lengths) - prev)
180
+
181
+ len_groups = torch.split(lengths, group_sizes)
182
+ feature_sizes = [group.sum().item() for group in len_groups]
183
+ feature_groups = torch.split(features, feature_sizes)
184
+
185
+ return feature_groups, len_groups
186
+
187
+ def encode_batch(self, input_features: torch.Tensor, input_lens: torch.Tensor, max_length: int = 256000):
188
+ feature_groups, len_groups = self.group_by_length(input_features, input_lens, max_length)
189
+
190
+ encoded_parts = []
191
+ for features, lengths in zip(feature_groups, len_groups):
192
+ with torch.no_grad():
193
+ codes, _ = self.mimo_audio_tokenizer.encoder.encode(
194
+ input_features=features.to(self.device),
195
+ input_lens=lengths.to(self.device),
196
+ return_codes_only=True
197
+ )
198
+ encoded_parts.append(codes)
199
+
200
+ return torch.cat(encoded_parts, dim=-1)
201
+
202
+ def preprocess_input(
203
+ self,
204
+ input: Union[str, torch.Tensor],
205
+ ):
206
+ if isinstance(input, torch.Tensor):
207
+ wav = input
208
+ else:
209
+ wav, sr = torchaudio.load(input)
210
+ if wav.ndim == 2:
211
+ wav = wav.mean(dim=0)
212
+ wav = self.resample_audio_if_needed(wav, sr)
213
+ wav = wav.to(self.device)
214
+
215
+ # Split waveform into 30s chunks, tokenize each separately, then concatenate codes
216
+ target_sr = self.mimo_audio_tokenizer.config.sampling_rate
217
+ chunk_samples = 30 * target_sr
218
+ n_fft = self.mimo_audio_tokenizer.config.nfft
219
+
220
+ total_samples = wav.shape[-1]
221
+ code_parts = []
222
+ start = 0
223
+ while start < total_samples:
224
+ end = min(start + chunk_samples, total_samples)
225
+ # Merge a too-short trailing chunk (would break mel reflect padding)
226
+ # into the current one.
227
+ if 0 < total_samples - end < n_fft:
228
+ end = total_samples
229
+ chunk = wav[start:end]
230
+ # Zero-pad if the entire audio is shorter than n_fft.
231
+ if chunk.shape[-1] < n_fft:
232
+ chunk = torch.nn.functional.pad(chunk, (0, n_fft - chunk.shape[-1]))
233
+ mel = self.wav2mel(chunk).transpose(0, 1) # (seq_len, n_mels)
234
+ codes_chunk = self.encode_batch(
235
+ input_features=mel,
236
+ input_lens=torch.tensor([mel.size(0)]),
237
+ )
238
+ code_parts.append(codes_chunk)
239
+ start = end
240
+
241
+ codes_packed = torch.cat(code_parts, dim=-1)
242
+ codes = codes_packed.transpose(0, 1).detach().cpu()
243
+ audio_codes = codes[:, :self.audio_channels]
244
+
245
+ # Pad the sequence to be a multiple of group_size by repeating the last frame
246
+ num_timesteps = audio_codes.shape[0]
247
+ if num_timesteps % self.group_size != 0:
248
+ padding_needed = self.group_size - (num_timesteps % self.group_size)
249
+ last_tokens = audio_codes[-1:, :] # Keep dim for repeat
250
+ padding_tokens = last_tokens.repeat(padding_needed, 1)
251
+ audio_codes = torch.cat([audio_codes, padding_tokens], dim=0)
252
+
253
+ audio_tokenized = audio_codes.reshape(-1)
254
+
255
+ return audio_tokenized
256
+
257
+ def get_input_ids(self, prompt):
258
+ input_ids = [
259
+ seg.to_input_id(
260
+ self.tokenizer,
261
+ self.group_size,
262
+ self.audio_channels,
263
+ )
264
+ for seg in prompt
265
+ ]
266
+ input_ids = torch.cat(input_ids, dim=1)
267
+ return input_ids.to(self.device)
268
+
269
+
270
+ def get_asr_sft_prompt(
271
+ self,
272
+ input: Union[None, str] = None,
273
+ audio_tag="",
274
+ ):
275
+ audio_tokenized = self.preprocess_input(input)
276
+
277
+ if '<chinese>' in audio_tag:
278
+ template = random.choice(asr_zh_templates)
279
+ elif '<english>' in audio_tag:
280
+ template = random.choice(asr_en_templates)
281
+ else:
282
+ template = random.choice(asr_zh_templates + asr_en_templates)
283
+
284
+ lm_prompt = [
285
+ InputSegment(
286
+ text=f"<|im_start|>user\n",
287
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
288
+ text_zeroemb_idx=self.empty_token,
289
+ ),
290
+ InputSegment(
291
+ audio=audio_tokenized,
292
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
293
+ text_zeroemb_idx=self.empty_token,
294
+ ),
295
+ InputSegment(
296
+ text=template,
297
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
298
+ text_zeroemb_idx=self.empty_token,
299
+ ),
300
+ InputSegment(
301
+ text=f"<|im_end|>\n",
302
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
303
+ text_zeroemb_idx=self.empty_token,
304
+ ),
305
+ InputSegment(
306
+ text=f"<|im_start|>assistant\n",
307
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
308
+ text_zeroemb_idx=self.empty_token,
309
+ ),
310
+ InputSegment(
311
+ text=f"<think>\n\n</think>\n{audio_tag}",
312
+ speech_zeroemb_idx=self.speech_zeroemb_idx,
313
+ text_zeroemb_idx=self.empty_token,
314
+ )
315
+ ]
316
+ input_ids = self.get_input_ids(lm_prompt)
317
+ return input_ids
318
+
319
+
320
+ @torch.no_grad()
321
+ def forward(
322
+ self,
323
+ input_ids,
324
+ stopping_criteria=None,
325
+ min_new_tokens=0,
326
+ max_new_tokens=8192,
327
+ task_name=None,
328
+ ):
329
+
330
+ task_sampler = self.get_task_sampler(task_name)
331
+
332
+ generation_kwargs = self.generate_kwargs.copy()
333
+ generation_config = GenerationConfig(**generation_kwargs)
334
+
335
+ input_ids = input_ids.T.reshape(1, -1) # [B, flattened(T, audio_channels + 1)]
336
+
337
+ prompt_length = input_ids.shape[1] // (self.audio_channels+1)
338
+
339
+ max_length = prompt_length // self.group_size + max_new_tokens
340
+ min_length = prompt_length // self.group_size + min_new_tokens
341
+
342
+ if stopping_criteria is not None:
343
+ for criterion in stopping_criteria:
344
+ if isinstance(criterion, MiMoStopper):
345
+ criterion.max_length = max_length
346
+ criterion.min_length = min_length
347
+
348
+ generated_ids = self.model.generate(
349
+ input_ids,
350
+ generation_config,
351
+ stopping_criteria=stopping_criteria,
352
+ global_sampler=task_sampler["global"],
353
+ local_sampler=task_sampler["local"],
354
+ )
355
+
356
+ generated_ids = generated_ids.int().cpu().reshape(-1, self.audio_channels+1).T[:, prompt_length:]
357
+
358
+ text = generated_ids[0, ::self.group_size][:-1]
359
+ detokenized_text = self.tokenizer.decode(text, skip_special_tokens=False).strip().replace("<|empty|>", "").replace("<|eot|>", "").replace("<|eostm|>", "")
360
+ print("Text channel:\t", detokenized_text)
361
+
362
+ return detokenized_text
363
+
364
+ def asr_sft(self, audio, audio_tag=""):
365
+ stopping_criteria = [
366
+ MiMoStopper(
367
+ stop_tokens=[self.tokenizer.eos_token_id, self.im_end_idx],
368
+ group_size=self.group_size,
369
+ audio_channels=self.audio_channels,
370
+ )
371
+ ]
372
+ input_ids = self.get_asr_sft_prompt(audio, audio_tag)
373
+ result = self.forward(input_ids, stopping_criteria=stopping_criteria, task_name="asr")
374
+ if '<chinese>' in result or '<english>' in result:
375
+ result = result.replace('<chinese>', '').replace('<english>', '').strip()
376
+ return result
src/mimo_audio/modeling_mimo_audio.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ import copy
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from typing import List, Optional, Union, cast
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch import nn
10
+ from transformers import StoppingCriteria
11
+ from transformers.cache_utils import Cache, DynamicCache
12
+ from transformers.generation.streamers import BaseStreamer
13
+ from transformers.generation.utils import (
14
+ GenerateOutput,
15
+ GenerationConfig,
16
+ StoppingCriteriaList,
17
+ is_deepspeed_zero3_enabled,
18
+ )
19
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
20
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
21
+ from transformers.models.qwen2.modeling_qwen2 import (
22
+ Qwen2Model,
23
+ Qwen2PreTrainedModel,
24
+ )
25
+ from transformers.utils import is_torchdynamo_compiling
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
class MiMoStopper(StoppingCriteria):
    """Stopping criterion over the flattened text-plus-audio id stream.

    One generation "step" spans ``(audio_channels + 1) * group_size`` ids.
    Generation stops once the number of steps reaches ``max_length``, or when
    the id at offset ``-step`` of the first sequence is a stop token and at
    least ``min_length`` steps have been produced.
    """

    def __init__(
        self,
        group_size: int,
        audio_channels: int,
        stop_tokens: list[int] | None = None,
        max_length: int | None = None,
        min_length: int | None = None,
    ) -> None:
        super().__init__()
        self.group_size = group_size
        self.audio_channels = audio_channels
        # Number of flattened ids that make up one step.
        self.step = group_size * (audio_channels + 1)

        self.stop_token_ids = set(stop_tokens or [])

        self.max_length = max_length
        self.min_length = min_length or 0

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
        """Return a bool tensor with the same stop decision for every row."""
        steps_so_far = input_ids.shape[-1] // self.step

        # A max_length of None or 0 disables the length check.
        finished = bool(self.max_length) and steps_so_far >= self.max_length

        may_check_token = (
            bool(self.stop_token_ids)
            and input_ids.shape[1] >= self.step
            and steps_so_far >= self.min_length
        )
        if may_check_token:
            # Only the first sequence in the batch is inspected.
            candidate = input_ids[0, -self.step].item()
            finished = finished or candidate in self.stop_token_ids

        batch = input_ids.shape[0]
        return torch.full(
            (batch,), finished, device=input_ids.device, dtype=torch.bool
        )
66
+
67
+
68
+ @dataclass
69
+ class MiMoSampler:
70
+ do_sample: bool | None = None
71
+ temperature: float | None = None
72
+ top_k: int | None = None
73
+ top_p: float | None = None
74
+
75
+ def process(self, scores: torch.Tensor):
76
+ if self.temperature is not None:
77
+ scores = scores / self.temperature
78
+
79
+ if self.top_k is not None and self.top_k > 0:
80
+ top_k = min(self.top_k, scores.shape[-1])
81
+ indices_to_remove = scores < torch.topk(scores, top_k)[0][:, -1]
82
+ scores = scores.masked_fill(indices_to_remove, float("-inf"))
83
+
84
+ if self.top_p is not None and 0.0 < self.top_p <= 1.0:
85
+ top_p = self.top_p if 0.0 < self.top_p <= 1.0 else 1.0
86
+ sorted_logits, sorted_indices = torch.sort(scores)
87
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
88
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
89
+ sorted_indices_to_remove[:, -1] = 0
90
+ indices_to_remove = sorted_indices_to_remove.scatter(
91
+ 1, sorted_indices, sorted_indices_to_remove
92
+ )
93
+ scores = scores.masked_fill(indices_to_remove, float("-inf"))
94
+
95
+ return scores
96
+
97
+ def sample(self, scores: torch.Tensor, removed_tokens: list[int] | None = None):
98
+ scores = self.process(scores)
99
+ for t in removed_tokens or []:
100
+ scores[:, t] = float("-inf")
101
+
102
+ if self.do_sample:
103
+ probs = scores.softmax(dim=-1)
104
+ return torch.multinomial(probs, num_samples=1).squeeze(-1)
105
+
106
+ return torch.argmax(scores, dim=-1)
107
+
108
+
109
@dataclass
class MiMoAudioOutput(ModelOutput):
    """Output container with text logits, local hidden states and the KV cache."""

    text_logits: torch.FloatTensor | None = None
    # Original note: "Downcast hidden states for local transformer generation"
    # -- presumably describes this field; TODO confirm.
    local_hidden_states: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
115
+
116
+
117
# NOTE(review): @dataclass is applied to a class with a hand-written __init__
# and no annotated fields. The decorator leaves __init__ alone but appears to
# generate a field-less __repr__/__eq__ -- confirm this is intended.
@dataclass
class MiMoAudioConfig(Qwen2Config):
    """Qwen2 configuration extended with speech-channel and local-transformer settings.

    ``speech_vocab_size``, ``speech_zeroemb_idx`` and ``delay_pattern`` accept
    either one value or a dash-separated per-channel string such as
    ``"1025-1025-129"``; use the ``parsed_*`` methods to read them as lists.
    """

    def __init__(
        self,
        *,
        speech_vocab_size: str | int = "1025-1025-129-129-129-129-129-129",
        speech_zeroemb_idx: str | int = "1024-1024-128-128-128-128-128-128",
        delay_pattern: str = "0-1-2-3-4-5-6-7",
        head_dim: int = 128,
        group_size: int = 4,
        audio_channels: int = 8,
        local_dim: int = 1024,
        local_layers: int = 16,
        local_attn_heads: int = 64,
        local_ffn_dim: int = 4096,
        local_attn_dropout: float = 0.1,
        input_local_layers: int = 6,
        input_local_dim: int | None = None,
        input_full_attention: bool | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Per-channel speech settings (raw, unparsed form).
        self.speech_vocab_size = speech_vocab_size
        self.speech_zeroemb_idx = speech_zeroemb_idx
        self.delay_pattern = delay_pattern

        self.head_dim = head_dim

        self.group_size = group_size
        self.audio_channels = audio_channels

        # Output-side local transformer.
        self.local_dim = local_dim
        self.local_layers = local_layers
        self.local_attn_heads = local_attn_heads
        self.local_ffn_dim = local_ffn_dim
        self.local_attn_dropout = local_attn_dropout

        # Input-side local transformer; its width defaults to local_dim.
        self.input_local_layers = input_local_layers
        self.input_local_dim = input_local_dim or local_dim

        self.input_full_attention = input_full_attention

    def _parse_maybe_list(self, value: str | int, length: int) -> List[int]:
        """Parse ``"a-b-c"`` into ``[a, b, c]``; otherwise repeat ``int(value)``."""
        if not (isinstance(value, str) and "-" in value):
            return [int(value)] * length
        return [int(piece) for piece in value.split("-")]

    def parsed_speech_empty_ids(self):
        """Return ``speech_zeroemb_idx`` as a per-channel list of ints."""
        return self._parse_maybe_list(self.speech_zeroemb_idx, self.audio_channels)

    def parsed_speech_vocab_sizes(self):
        """Return ``speech_vocab_size`` as a per-channel list of ints."""
        return self._parse_maybe_list(self.speech_vocab_size, self.audio_channels)

    def parsed_delay_pattern(self):
        """Return ``delay_pattern`` as a per-channel list of ints."""
        return self._parse_maybe_list(self.delay_pattern, self.audio_channels)

    def _derived_config(self, hidden_size, num_layers, ffn_dim):
        """Deep-copy this config and override the transformer dimensions."""
        derived = copy.deepcopy(self)

        heads = self.local_attn_heads
        derived.hidden_size = hidden_size
        derived.num_hidden_layers = num_layers
        derived.num_attention_heads = heads
        derived.num_key_value_heads = heads
        derived.head_dim = derived.hidden_size // heads
        derived.intermediate_size = ffn_dim
        derived.attention_dropout = self.local_attn_dropout

        return derived

    def local_config(self):
        """Return a copy of this config sized for the local transformer."""
        return self._derived_config(
            self.local_dim, self.local_layers, self.local_ffn_dim
        )

    def input_local_config(self):
        """Return a copy sized for the input local transformer (FFN = 4x width)."""
        width = self.input_local_dim
        return self._derived_config(width, self.input_local_layers, width * 4)
200
+
201
+
202
@dataclass
class MiMoAudioArguments:
    """Model path plus the ids of the special tokens the model relies on."""

    model_name_or_path: str
    sosp_idx: int
    eosp_idx: int
    sostm_idx: int
    eostm_idx: int
    eot_idx: int
    empty_idx: int

    def to_dict(self):
        """Return the seven arguments as a plain ``dict`` keyed by field name."""
        keys = (
            "model_name_or_path",
            "sosp_idx",
            "eosp_idx",
            "sostm_idx",
            "eostm_idx",
            "eot_idx",
            "empty_idx",
        )
        return {key: getattr(self, key) for key in keys}
222
+
223
+
224
+ class MiMoAudioForCausalLM(Qwen2PreTrainedModel):
225
def __init__(
    self,
    config: MiMoAudioConfig | Qwen2Config,
    args: MiMoAudioArguments | dict,
):
    """Assemble the backbone, both local transformers and the speech heads.

    Args:
        config: Model configuration. A ``Qwen2Config`` instance is rebuilt
            as a ``MiMoAudioConfig`` from its attribute dict.
        args: Special-token indices, either a ``MiMoAudioArguments`` or a
            dict holding its fields.
    """
    super().__init__(config)

    if isinstance(config, Qwen2Config):
        config = MiMoAudioConfig(**vars(config))
    if isinstance(args, dict):
        args = MiMoAudioArguments(**args)
    self.config = config
    self.args = args

    # NOTE: sub-modules are created in the same order as before; the order
    # fixes parameter registration order, so do not shuffle these statements.
    self.model = Qwen2Model(config)

    self.speech_vocab_sizes = config.parsed_speech_vocab_sizes()
    self.speech_empty_ids = config.parsed_speech_empty_ids()
    self.delay_pattern = config.parsed_delay_pattern()

    self.group_size = config.group_size
    self.audio_channels = config.audio_channels

    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    # Local transformer: consumes embeddings directly, so it keeps no
    # token embedding table of its own.
    self.local_config = config.local_config()
    self.local_transformer = Qwen2Model(self.local_config)
    self.local_transformer.embed_tokens = None

    # Input local transformer: always built, also without an embedding table.
    self.input_local_config = config.input_local_config()
    self.input_local_transformer = Qwen2Model(self.input_local_config)
    self.input_local_transformer.embed_tokens = None

    local_width = self.local_config.hidden_size
    input_width = self.input_local_config.hidden_size

    # One output head per audio channel on top of the local transformer.
    channel_heads = [
        nn.Linear(local_width, self.speech_vocab_sizes[channel], bias=False)
        for channel in range(self.audio_channels)
    ]
    self.local_transformer_lm_heads = nn.ModuleList(channel_heads)

    # One embedding table per audio channel; the channel's "empty" id is
    # registered as the padding index.
    channel_tables = [
        nn.Embedding(
            self.speech_vocab_sizes[channel],
            input_width,
            padding_idx=self.speech_empty_ids[channel],
        )
        for channel in range(self.audio_channels)
    ]
    self.speech_embeddings = nn.ModuleList(channel_tables)

    # Bridge from the input-local width to the local width, only when the
    # two differ.
    if input_width == local_width:
        self.speech_embeddings_to_local = None
    else:
        self.speech_embeddings_to_local = nn.Linear(
            input_width, local_width, bias=False
        )

    # Collapses one whole group of speech embeddings into a single vector of
    # the backbone width.
    self.speech_group_downcast = nn.Linear(
        input_width * config.group_size, config.hidden_size, bias=False
    )

    # Projects backbone hidden states down to the local transformer width.
    self.hidden_states_downcast = nn.Linear(
        config.hidden_size, local_width, bias=False
    )

    # Initialize weights and apply final processing
    self.post_init()
307
+
308
def apply_input_local_transformer(self, speech_embeddings: torch.Tensor):
    """Encode every group of speech embeddings with the input local transformer.

    Args:
        speech_embeddings: Tensor of shape ``[B, T_groups, group_size, hidden]``.

    Returns:
        The transformer's last hidden state, reshaped back to
        ``[B, T_groups, group_size, hidden]``.
    """
    batch, n_groups, per_group, width = speech_embeddings.shape

    # Fold the group axis into the batch axis so each group is an
    # independent sequence of length ``group_size``.
    flat_groups = speech_embeddings.reshape(batch * n_groups, per_group, width)
    causal = not self.config.input_full_attention

    encoded = self.input_local_transformer(
        inputs_embeds=flat_groups,
        return_dict=True,
        is_causal=causal,  # for SDPA
    ).last_hidden_state

    return encoded.reshape(batch, n_groups, per_group, width)
330
+
331
+ def _prepare_input_embeds(
332
+ self,
333
+ input_ids: torch.LongTensor, # [B, audio_channels + 1, new_T]
334
+ ):
335
+ B = input_ids.shape[0]
336
+
337
+ input_ids = input_ids.int()
338
+ group_size = self.config.group_size
339
+
340
+ text_input_ids = input_ids[:, 0, ::group_size]
341
+ speech_input_ids = (
342
+ input_ids[:, 1:, :]
343
+ .view(B, self.audio_channels, -1, group_size)
344
+ .transpose(1, 2)
345
+ ) # [B, T//group_size, audio_channels, group_size]
346
+
347
+ is_speech = text_input_ids == self.args.empty_idx # [B, T//group_size]
348
+
349
+ speech_embeds = torch.zeros(
350
+ (
351
+ B,
352
+ is_speech.shape[1],
353
+ group_size,
354
+ self.input_local_config.hidden_size,
355
+ ),
356
+ device=input_ids.device,
357
+ dtype=torch.bfloat16,
358
+ )
359
+
360
+ for idx in range(self.audio_channels):
361
+ cur_empty = self.speech_empty_ids[idx]
362
+ cur_embed = self.speech_embeddings[idx]
363
+ cur_speech_ids = speech_input_ids[:, :, idx, :]
364
+ cur_speech_embeds: torch.Tensor = cur_embed(cur_speech_ids)
365
+ # [B, T_groups, group_size, hidden_size]
366
+
367
+ cur_mask = cur_speech_ids == cur_empty
368
+ cur_speech_embeds.masked_fill_(cur_mask.unsqueeze(-1), 0.0)
369
+
370
+ speech_embeds += cur_speech_embeds
371
+
372
+ speech_embeds = speech_embeds * is_speech.unsqueeze(-1).unsqueeze(-1)
373
+
374
+ # Apply input local transformer if configured
375
+ speech_embeds = self.apply_input_local_transformer(speech_embeds)
376
+ speech_embeds = speech_embeds * is_speech.unsqueeze(-1).unsqueeze(-1)
377
+
378
+ T_groups = speech_embeds.shape[1]
379
+ speech_grouped_embeds: torch.Tensor = self.speech_group_downcast(
380
+ speech_embeds.view(B, T_groups, -1)
381
+ ) # [B, T_groups, hidden_size]
382
+
383
+ text_embeds: torch.Tensor = self.model.embed_tokens(text_input_ids)
384
+ text_zero_mask = text_input_ids == self.args.empty_idx
385
+ text_embeds.masked_fill_(text_zero_mask.unsqueeze(-1), 0.0)
386
+
387
+ return text_embeds + speech_grouped_embeds
388
+
389
def forward(
    self,
    input_ids: torch.LongTensor,  # [B, audio_channels + 1, new_T]
    attention_mask: torch.Tensor,  # [B, T_group]
    position_ids: torch.LongTensor,  # [B, new_T_group]
    past_key_values: Cache | None = None,
    cache_position: torch.LongTensor | None = None,  # [new_T_group]
    **_kwargs,
):
    """Run the backbone and return outputs for the last position only.

    Any extra keyword arguments are accepted and ignored. The key/value
    cache is always enabled.

    Returns:
        ``MiMoAudioOutput`` with ``text_logits`` (``[B, 1, vocab_size]``),
        ``local_hidden_states`` (last hidden state projected to the local
        transformer width, ``[B, 1, hidden_size]``) and the updated
        ``past_key_values``.
    """
    backbone_out: BaseModelOutputWithPast = self.model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=self._prepare_input_embeds(input_ids),
        use_cache=True,
        return_dict=True,
        cache_position=cache_position,
    )

    # Only the newest position is needed for next-token prediction.
    last_step = backbone_out.last_hidden_state[:, -1:, :]  # [B, 1, hidden_size]

    return MiMoAudioOutput(
        text_logits=self.lm_head(last_step),
        local_hidden_states=self.hidden_states_downcast(last_step),
        past_key_values=backbone_out.past_key_values,
    )
423
+
424
def local_forward(
    self,
    local_embeds: torch.FloatTensor,  # [B, 1, hidden_size]
    tokens_dtype: torch.dtype,
    tokens_device: torch.device,
    local_sampler: MiMoSampler | None = None,
):
    """Sample one group of speech tokens with the local transformer.

    The local transformer is stepped ``group_size + max(delay_pattern)``
    times with a fresh key/value cache. Channel ``c`` is sampled on steps
    ``delay_pattern[c] .. delay_pattern[c] + group_size - 1``. The
    embeddings of the tokens sampled on a step are summed and become the
    next step's input (all zeros when no channel was active).

    Args:
        local_embeds: Initial input of shape ``[B, 1, hidden_size]``.
        tokens_dtype: dtype of the returned token tensor.
        tokens_device: device of the returned token tensor.
        local_sampler: Sampler for speech tokens; a default ``MiMoSampler``
            is created when omitted.

    Returns:
        Token tensor of shape ``[B, group_size, audio_channels]``.
    """
    batch = local_embeds.shape[0]
    total_steps = self.group_size + max(self.delay_pattern)
    kv_cache = DynamicCache()
    sampled = torch.zeros(
        (batch, self.group_size, self.audio_channels),
        dtype=tokens_dtype,
        device=tokens_device,
    )
    if local_sampler is None:
        local_sampler = MiMoSampler()

    for step in range(total_steps):
        step_out: BaseModelOutputWithPast = self.local_transformer(
            inputs_embeds=local_embeds,
            past_key_values=kv_cache,
            return_dict=True,
            use_cache=True,
        )
        last_hidden = step_out.last_hidden_state
        kv_cache = step_out.past_key_values

        # Next input starts from zero and collects this step's tokens.
        local_embeds = torch.zeros_like(local_embeds)
        for channel in range(self.audio_channels):
            slot = step - self.delay_pattern[channel]
            if not 0 <= slot < self.group_size:
                continue  # channel not active on this step

            head = self.local_transformer_lm_heads[channel]
            scores: torch.Tensor = head(last_hidden)[:, -1, :]  # [B, vocab_size]
            token = local_sampler.sample(scores, [self.speech_empty_ids[channel]])
            sampled[:, slot, channel] = token

            token_embed = self.speech_embeddings[channel](token.unsqueeze(1))
            if self.speech_embeddings_to_local is not None:
                token_embed = self.speech_embeddings_to_local(token_embed)
            local_embeds += token_embed

    return sampled  # [B, group_size, audio_channels]
477
+
478
+ def _prepare_attention_mask(
479
+ self, inputs: torch.Tensor, input_ids_length: int
480
+ ) -> torch.Tensor:
481
+ # No information for attention mask inference -> return default attention mask
482
+ return torch.ones(
483
+ (inputs.shape[0], input_ids_length),
484
+ dtype=torch.bool,
485
+ device=inputs.device,
486
+ )
487
+
488
def prepare_inputs_for_generation(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[Cache] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """Build the keyword arguments for one ``forward`` call during generation.

    ``input_ids`` arrives flat (``[B, T * frame]`` with
    ``frame = (audio_channels + 1) * group_size``). It is viewed as one
    column per group so the generic cache-based slicing can drop groups that
    are already cached, then re-packed as
    ``[B, audio_channels + 1, T * group_size]`` for ``forward``.

    Returns:
        Dict of model inputs. ``labels`` is removed; every other entry of
        ``kwargs`` not set here is passed through unchanged.
    """
    batch = input_ids.shape[0]
    frame = (self.audio_channels + 1) * self.config.group_size
    model_inputs = {}

    # [B, T * frame] -> [B, frame, T]
    input_ids = input_ids.reshape(batch, -1, frame).transpose(1, 2)

    # 1. Cache position: pass it through for Cache-class models, otherwise
    #    derive it from the legacy tuple cache when it was not given.
    if self._supports_cache_class:
        model_inputs["cache_position"] = cache_position
    elif cache_position is None:
        past_length = (
            0 if past_key_values is None else past_key_values[0][0].shape[2]
        )
        cache_position = torch.arange(
            past_length,
            input_ids.shape[2],
            dtype=torch.long,
            device=input_ids.device,
        )

    # 2. With a cache, keep only the groups that were not processed yet.
    if past_key_values is not None:
        model_inputs["past_key_values"] = past_key_values
        if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[2]:
            # inputs_embeds in use, or cache_position ran past the input
            # (synced GPUs): keep the trailing entries only.
            input_ids = input_ids[:, :, -cache_position.shape[0] :]
        elif input_ids.shape[2] != cache_position.shape[0]:
            # Default case; when the lengths already match nothing is sliced.
            input_ids = input_ids[:, :, cache_position]

    # 3. Base inputs. `inputs_embeds` is only honoured on the first step.
    is_encoder_decoder = self.config.is_encoder_decoder
    input_ids_key = "decoder_input_ids" if is_encoder_decoder else "input_ids"
    first_step_embeds = (
        not is_encoder_decoder
        and inputs_embeds is not None
        and cache_position[0] == 0
    )
    if first_step_embeds:
        model_inputs[input_ids_key] = None
        model_inputs["inputs_embeds"] = inputs_embeds
    else:
        # `clone` keeps the stride consistent across steps.
        model_inputs[input_ids_key] = input_ids.clone(
            memory_format=torch.contiguous_format
        )
        if not is_encoder_decoder:
            model_inputs["inputs_embeds"] = None

    # 4. Derive `position_ids` from the attention mask when missing.
    if attention_mask is not None and kwargs.get("position_ids") is None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        kwargs["position_ids"] = position_ids  # sliced in step 5

    # 5. Inputs that must track the length of `input_ids`.
    for name in ("position_ids", "token_type_ids"):
        aligned: torch.Tensor = kwargs.get(name)
        if aligned is None:
            continue
        if past_key_values:
            aligned = aligned[:, -input_ids.shape[2] :]
            aligned = aligned.clone(memory_format=torch.contiguous_format)
        model_inputs[name] = aligned

    # 6. Attention mask is forwarded untouched.
    if attention_mask is not None:
        model_inputs["attention_mask"] = attention_mask

    # 7. Pass through every remaining kwarg (e.g. `use_cache`).
    for key, value in kwargs.items():
        if key not in model_inputs:
            model_inputs[key] = value

    # Re-pack [B, frame, T] -> [B, audio_channels + 1, T * group_size].
    packed = model_inputs[input_ids_key]
    if packed is not None:
        model_inputs[input_ids_key] = (
            cast(torch.Tensor, packed)
            .transpose(1, 2)
            .reshape(input_ids.shape[0], -1, (self.audio_channels + 1))
            .transpose(1, 2)
        )

    # 8. `labels` is not a forward argument here.
    model_inputs.pop("labels", None)
    return model_inputs
601
+
602
+ def _get_initial_cache_position(self, input_ids: torch.Tensor, model_kwargs: dict):
603
+ """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
604
+ # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
605
+ if "inputs_embeds" in model_kwargs:
606
+ cache_position = (
607
+ torch.ones_like(
608
+ model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64
609
+ ).cumsum(0)
610
+ - 1
611
+ )
612
+ else:
613
+ cache_position = (
614
+ torch.ones(
615
+ (
616
+ input_ids.shape[1]
617
+ // (self.audio_channels + 1)
618
+ // self.config.group_size,
619
+ ),
620
+ dtype=torch.int64,
621
+ device=input_ids.device,
622
+ ).cumsum(0)
623
+ - 1
624
+ )
625
+
626
+ past_length = 0
627
+ if model_kwargs.get("past_key_values") is not None:
628
+ cache = model_kwargs["past_key_values"]
629
+ past_length = 0
630
+ if not isinstance(cache, Cache):
631
+ past_length = cache[0][0].shape[2]
632
+ elif (
633
+ hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None
634
+ ):
635
+ past_length = cache.get_seq_length()
636
+
637
+ # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty,
638
+ # end-to-end compilation will yield bad results because `cache_position` will be incorrect.
639
+ if not is_torchdynamo_compiling():
640
+ cache_position = cache_position[past_length:]
641
+
642
+ model_kwargs["cache_position"] = cache_position
643
+
644
+ return model_kwargs
645
+
646
@torch.inference_mode()
def generate(
    self,
    inputs: torch.Tensor | None = None,
    generation_config: GenerationConfig | None = None,
    stopping_criteria: StoppingCriteriaList | list | None = None,
    streamer: BaseStreamer | None = None,
    synced_gpus: bool | None = None,
    global_sampler: MiMoSampler | None = None,
    local_sampler: MiMoSampler | None = None,
    warmup_run: bool | None = None,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    """Generate interleaved text/speech tokens with ``slm_sample``.

    Args:
        inputs: Flat prompt ids. May be omitted when the prompt is passed
            through ``kwargs`` (e.g. ``input_ids=...``).
        generation_config: Generation settings; ``max_length`` and
            ``bos_token_id`` are read here.
        stopping_criteria: Extra stopping criteria. A ``MiMoStopper`` bound
            to ``generation_config.max_length`` is always appended.
        streamer: Optional streamer that receives the prompt and every
            newly generated frame.
        synced_gpus: Keep stepping until all peers finish. Defaults to
            ``True`` only under DeepSpeed ZeRO-3 with more than one process.
        global_sampler: Sampler for text tokens.
        local_sampler: Sampler for speech tokens.
        warmup_run: When truthy, run under the ``"default"`` compiler
            stance; otherwise ``"eager_on_recompile"`` is used.
        **kwargs: Forwarded to generation-config / model-kwargs preparation.

    Returns:
        The prompt followed by all generated tokens, as returned by
        ``slm_sample``.
    """
    generation_config, model_kwargs = self._prepare_generation_config(
        generation_config, **kwargs
    )
    self._validate_model_kwargs(model_kwargs.copy())

    # Default for multi-process generation.
    if synced_gpus is None:
        synced_gpus = bool(
            is_deepspeed_zero3_enabled() and dist.get_world_size() > 1
        )

    # Resolve the prompt, whether it came positionally or through kwargs.
    input_ids, _model_input_name, model_kwargs = self._prepare_model_inputs(
        inputs, generation_config.bos_token_id, model_kwargs
    )
    # Length in groups, not raw tokens.
    input_ids_length = input_ids.shape[-1] // (
        self.group_size * (self.audio_channels + 1)
    )

    if streamer is not None:
        streamer.put(input_ids.cpu())

    if "attention_mask" not in model_kwargs:
        # Build the default mask from the resolved `input_ids`: `inputs` is
        # None when the prompt was supplied as a keyword argument.
        model_kwargs["attention_mask"] = self._prepare_attention_mask(
            input_ids, input_ids_length
        )

    device = input_ids.device
    self._prepare_special_tokens(generation_config, True, device=device)

    # Every call starts from an empty cache.
    model_kwargs["use_cache"] = True
    model_kwargs["past_key_values"] = DynamicCache()

    prepared_stopping_criteria = StoppingCriteriaList(
        stopping_criteria if stopping_criteria is not None else []
    )
    prepared_stopping_criteria.append(
        MiMoStopper(
            self.group_size,
            self.audio_channels,
            max_length=generation_config.max_length,
        )
    )

    stance = "default" if warmup_run else "eager_on_recompile"
    with torch.compiler.set_stance(stance):
        return self.slm_sample(
            input_ids,
            stopping_criteria=prepared_stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            streamer=streamer,
            global_sampler=global_sampler,
            local_sampler=local_sampler,
            **model_kwargs,
        )
715
+
716
def slm_sample(
    self,
    input_ids: torch.LongTensor,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: BaseStreamer | None,
    global_sampler: MiMoSampler | None = None,
    local_sampler: MiMoSampler | None = None,
    **model_kwargs,
) -> torch.LongTensor:
    """Autoregressively sample one group (frame) per step until stopped.

    Each step samples a text token from the backbone. If it equals
    ``args.empty_idx`` the local transformer samples the group's speech
    tokens; otherwise the speech channels are filled with their "empty"
    ids. The text token is repeated ``group_size`` times, concatenated with
    the speech tokens and appended to ``input_ids``.

    Only batch size 1 is supported: the text/speech decision looks at the
    first sequence only.

    Args:
        input_ids: Flat prompt ids, ``[B, T * group_size * (audio_channels + 1)]``.
        stopping_criteria: Criteria evaluated after every step. Any
            ``MiMoStopper`` also contributes ``min_length`` and stop tokens.
        generation_config: Supplies ``max_length``.
        synced_gpus: Keep running forward passes until all peers finish.
        streamer: Optional streamer that receives every new frame.
        global_sampler: Sampler for text tokens; a default ``MiMoSampler``
            is created when omitted.
        local_sampler: Sampler for speech tokens, forwarded to
            ``local_forward``.
        **model_kwargs: Model inputs carried across steps.

    Returns:
        ``input_ids`` extended with every generated frame.
    """
    # Mirror `local_forward`: a missing sampler falls back to the default
    # one instead of failing on `None.sample(...)`.
    if global_sampler is None:
        global_sampler = MiMoSampler()

    max_length = generation_config.max_length
    frame_width = self.group_size * (self.audio_channels + 1)

    B, cur_len = input_ids.shape
    cur_len //= frame_width  # length in groups
    this_peer_finished = False
    unfinished_sequences = torch.ones(B, dtype=torch.long, device=input_ids.device)

    # Collect the minimum length and stop tokens from MiMo stoppers so stop
    # tokens can be banned until the minimum length is reached.
    min_length = 0
    stop_token_ids = set()
    for criterion in stopping_criteria:
        if isinstance(criterion, MiMoStopper):
            if criterion.min_length is not None:
                min_length = max(min_length, criterion.min_length)
            stop_token_ids.update(criterion.stop_token_ids)

    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    while self._has_unfinished_sequences(
        this_peer_finished,
        synced_gpus,
        device=input_ids.device,
        cur_len=cur_len,
        max_length=max_length,
    ):
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        step_tokens = cast(torch.Tensor, model_inputs["input_ids"]).shape[2]
        if step_tokens != self.group_size:
            # Prefill run: more than one group at once, keep it eager.
            with torch.compiler.set_stance("force_eager"):
                outputs: MiMoAudioOutput = self(**model_inputs)
        else:
            outputs: MiMoAudioOutput = self(**model_inputs)

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        text_logits: torch.Tensor = outputs.text_logits[:, -1, :].clone()
        # [B, vocab_size]

        removed_tokens = None
        if cur_len < min_length:
            removed_tokens = list(stop_token_ids)

        next_text_tokens = global_sampler.sample(
            text_logits, removed_tokens=removed_tokens
        )  # [B]

        local_hidden_states = outputs.local_hidden_states

        # Only supports batch_size=1 here
        if next_text_tokens[0] != self.args.empty_idx:
            # Text step: every speech channel gets its "empty" id.
            zero_embed_tensor = torch.tensor(
                self.speech_empty_ids,
                device=next_text_tokens.device,
                dtype=input_ids.dtype,
            )
            next_speech_tokens = zero_embed_tensor.view(
                1, 1, self.audio_channels
            ).expand(B, self.config.group_size, -1)
        else:
            # Speech step: let the local transformer fill the group.
            next_speech_tokens = self.local_forward(
                local_embeds=local_hidden_states,
                tokens_dtype=next_text_tokens.dtype,
                tokens_device=next_text_tokens.device,
                local_sampler=local_sampler,
            )

        next_text_tokens = next_text_tokens.reshape(B, 1, 1).expand(
            -1, self.group_size, -1
        )  # [B, group_size, 1]

        next_tokens = torch.cat(
            (next_text_tokens, next_speech_tokens), dim=-1
        ).reshape(B, -1)  # [B, group_size * (audio_channels + 1)]

        input_ids = torch.cat(
            [input_ids, next_tokens], dim=-1
        )  # [B, T*group_size*vq]

        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(
            input_ids, None
        )
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1

        # Drop the reference so the (possibly large) first-step outputs are
        # not kept alive into the next iteration.
        del outputs

    if streamer is not None:
        streamer.end()

    return input_ids
src/mimo_audio/process_speechdata.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corporation.
3
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from typing import Tuple, Union, List
8
+
9
+
10
class InputSegment:
    """One prompt segment, either text-only or audio-only.

    ``to_input_id`` renders the segment as an integer matrix with one text
    row followed by ``audio_channels`` speech rows. Text segments fill the
    speech rows with ``speech_zeroemb_idx``; audio segments fill the text
    row with ``text_zeroemb_idx``. Within each group of ``group_size``
    columns only the first text slot holds a token, the rest hold ``-100``.
    """

    def __init__(
        self,
        text: str = "",
        audio: torch.Tensor = None,
        tokenized_text: torch.Tensor = None,
        speech_zeroemb_idx: Union[int, List[int]] = 1024,
        text_zeroemb_idx: int = 152067,
        add_sosp_eosp=True,
    ) -> None:
        """Store the segment payload and its placeholder ids.

        Args:
            text: Raw text, tokenized lazily when ``tokenized_text`` is None.
            audio: Flat audio codes; when given the segment is an audio segment.
            tokenized_text: Pre-tokenized 1-D text ids.
            speech_zeroemb_idx: Placeholder id for the speech rows, either
                one int for all channels or one int per channel.
            text_zeroemb_idx: Placeholder id for the text row of audio segments.
            add_sosp_eosp: Wrap audio segments in ``<|sosp|>`` / ``<|eosp|>``.
        """
        has_text = text is not None
        has_tokenized_text = tokenized_text is not None
        assert has_text or has_tokenized_text, "Text or tokenized text must be provided"

        self.audio = audio
        self.text = text
        self.tokenized_text = tokenized_text
        self.speech_zeroemb_idx = speech_zeroemb_idx
        self.text_zeroemb_idx = text_zeroemb_idx
        self.add_sosp_eosp = add_sosp_eosp

    @staticmethod
    def insert_between(tensor, i, value=-1):
        """Follow every column of a ``[1, n]`` tensor with ``i`` filler columns.

        The result has shape ``[1, n * (i + 1)]``; original entries land at
        columns ``0, i + 1, 2 * (i + 1), ...`` and all others hold ``value``.
        """
        n = tensor.shape[1]
        stride = i + 1
        # Helper tensors are created on the input's device so the scatter
        # also works for non-CPU tensors.
        spaced = torch.full(
            (1, n * stride), value, dtype=tensor.dtype, device=tensor.device
        )
        slots = (
            torch.arange(0, n, dtype=torch.int64, device=tensor.device)[None]
            * stride
        )
        return torch.scatter(spaced, 1, slots, tensor)

    def _empty_speech_rows(self, audio_channels: int, width: int) -> torch.Tensor:
        """Return ``[audio_channels, width]`` speech rows filled with placeholder ids."""
        if isinstance(self.speech_zeroemb_idx, list):
            rows = torch.zeros((audio_channels, width), dtype=torch.int)
            for channel, idx in enumerate(self.speech_zeroemb_idx):
                rows[channel, :] = idx
            return rows
        return torch.full(
            (audio_channels, width), self.speech_zeroemb_idx, dtype=torch.int
        )

    def _text_rows(self, tokenizer, group_size: int, audio_channels: int):
        """Build the text row and placeholder speech rows of a text segment."""
        if self.tokenized_text is None:
            tokenized_text = tokenizer(
                self.text,
                return_tensors="pt",
                truncation=True,
                max_length=999999,
                padding=False,
                add_special_tokens=False,
            )["input_ids"].int()
        else:
            tokenized_text = self.tokenized_text.unsqueeze(0)

        if group_size > 1:
            tokenized_text = self.insert_between(
                tokenized_text, group_size - 1, value=-100
            )

        speech_rows = self._empty_speech_rows(audio_channels, tokenized_text.shape[1])
        return tokenized_text, speech_rows

    def _audio_rows(self, tokenizer, group_size: int, audio_channels: int):
        """Build the placeholder text row and speech rows of an audio segment."""
        sosp_token = (
            tokenizer.convert_tokens_to_ids("<|sosp|>") if self.add_sosp_eosp else None
        )
        eosp_token = (
            tokenizer.convert_tokens_to_ids("<|eosp|>") if self.add_sosp_eosp else None
        )
        audio_part = self.audio.reshape(-1, audio_channels).T  # [audio_channels, seqlen]

        assert (
            audio_part.shape[1] % group_size == 0
        ), f"Audio shape {audio_part.shape} is not divisible by group_size {group_size}"

        # One placeholder text token per audio group.
        text_len = audio_part.shape[1] // group_size
        empty_token = self.text_zeroemb_idx
        if empty_token is None:
            empty_token = tokenizer.eod
        tokenized_text = torch.full((1, text_len), empty_token, dtype=torch.int)

        if self.add_sosp_eosp:
            tokenized_text = torch.cat(
                [
                    torch.tensor([[sosp_token]], dtype=torch.int),
                    tokenized_text,
                    torch.tensor([[eosp_token]], dtype=torch.int),
                ],
                dim=1,
            )
        tokenized_text = self.insert_between(
            tokenized_text, group_size - 1, value=-100
        )

        if self.add_sosp_eosp:
            # The marker tokens each occupy one full group of empty speech.
            marker_rows = self._empty_speech_rows(audio_channels, group_size)
            speech_rows = torch.cat([marker_rows, audio_part, marker_rows], dim=1)
        else:
            speech_rows = audio_part
        return tokenized_text, speech_rows

    def to_input_id(
        self,
        tokenizer,
        group_size: int,
        audio_channels: int = 8,
    ) -> torch.Tensor:
        """Render the segment as a ``[audio_channels + 1, seqlen]`` id matrix.

        Args:
            tokenizer: Tokenizer used for raw text and for looking up the
                ``<|sosp|>`` / ``<|eosp|>`` ids; unused for pre-tokenized text.
            group_size: Number of columns per group.
            audio_channels: Number of speech rows.

        Returns:
            A single tensor: the text row stacked on top of the speech rows.

        Raises:
            AssertionError: If the audio length per channel is not a
                multiple of ``group_size``.
        """
        if self.audio is None:
            text_row, speech_rows = self._text_rows(
                tokenizer, group_size, audio_channels
            )
        else:
            text_row, speech_rows = self._audio_rows(
                tokenizer, group_size, audio_channels
            )
        return torch.cat([text_row, speech_rows], dim=0)  # [n_rvq + 1, seqlen]
150
+
151
+
152
class StreamingInputSegment:
    """A text/audio pair rendered as alternating text and audio chunks.

    The text is cut into chunks of ``text_segment_size`` tokens and the
    audio into chunks of ``audio_segment_size`` groups. The chunks are
    interleaved (text first) between ``<|sostm|>`` and ``<|eostm|>``
    markers, and ``<|eot|>`` is appended to the last text chunk.
    """

    def __init__(
        self,
        text: str = "",
        audio: torch.Tensor = None,
        tokenized_text: torch.Tensor = None,
        speech_zeroemb_idx: Union[int, List[int]] = 1024,
        text_zeroemb_idx: int = 152067,
        text_segment_size: int = 5,
        audio_segment_size: int = 5,
        tokenizer=None,
        group_size=None,
        audio_channels=None,
    ) -> None:
        """Store the payload, the chunk sizes and optional encoding defaults.

        ``tokenizer``, ``group_size`` and ``audio_channels`` are optional
        here; when set they are used to encode the individual chunks in
        ``to_input_id``.
        """
        has_text = text is not None
        has_tokenized_text = tokenized_text is not None
        assert has_text or has_tokenized_text, "Text or tokenized text must be provided"

        self.audio = audio
        self.text = text
        self.tokenized_text = tokenized_text
        self.speech_zeroemb_idx = speech_zeroemb_idx
        self.text_zeroemb_idx = text_zeroemb_idx
        self.text_segment_size = text_segment_size
        self.audio_segment_size = audio_segment_size
        self.tokenizer = tokenizer
        self.group_size = group_size
        self.audio_channels = audio_channels

    def to_input_id(
        self,
        tokenizer,
        group_size: int,
        audio_channels: int = 8,
    ) -> torch.Tensor:
        """Render the interleaved stream as a ``[audio_channels + 1, seqlen]`` matrix.

        Args:
            tokenizer: Tokenizer used for the raw text and ``<|eot|>``.
            group_size: Columns per group; used to size the audio chunks.
            audio_channels: Number of speech rows; used to size the audio chunks.

        Returns:
            ``int64`` tensor with all chunks concatenated along the columns.

        Chunks are encoded with the tokenizer / group size / channel count
        given to the constructor when those were set, and with the arguments
        of this call otherwise.
        """
        if self.tokenized_text is None:
            tokenized_text = tokenizer(
                self.text,
                return_tensors="pt",
                truncation=True,
                max_length=999999,
                padding=False,
                add_special_tokens=False,
            )["input_ids"].int()  # [1, seqlen]
        else:
            tokenized_text = self.tokenized_text.unsqueeze(0)
        tokenized_text = tokenized_text.squeeze(0)

        text_segments = tokenized_text.split(self.text_segment_size, dim=0)
        audio_segments = self.audio.split(
            self.audio_segment_size * group_size * audio_channels, dim=0
        )

        # <|eot|> closes the text: glue it onto the last text chunk.
        eot_tokens = tokenizer(
            "<|eot|>",
            return_tensors="pt",
            truncation=True,
            max_length=999999,
            padding=False,
            add_special_tokens=False,
        )["input_ids"][0].to(text_segments[-1])
        text_segments = text_segments[:-1] + (
            torch.cat([text_segments[-1], eot_tokens], dim=0),
        )

        placeholders = {
            "speech_zeroemb_idx": self.speech_zeroemb_idx,
            "text_zeroemb_idx": self.text_zeroemb_idx,
        }

        def text_chunk(tokens):
            return InputSegment(tokenized_text=tokens, **placeholders)

        def audio_chunk(codes):
            return InputSegment(audio=codes, add_sosp_eosp=False, **placeholders)

        tokenized_segments = [InputSegment(text='<|sostm|>', **placeholders)]

        # Interleave while both streams have chunks, then flush the rest.
        length = min(len(text_segments), len(audio_segments))
        for text_segment, audio_segment in zip(
            text_segments[:length], audio_segments[:length]
        ):
            tokenized_segments.append(text_chunk(text_segment))
            tokenized_segments.append(audio_chunk(audio_segment))
        tokenized_segments.extend(text_chunk(seg) for seg in text_segments[length:])
        tokenized_segments.extend(audio_chunk(seg) for seg in audio_segments[length:])

        tokenized_segments.append(InputSegment(text="<|eostm|>", **placeholders))

        # Constructor-supplied encoding settings win when present; otherwise
        # use this call's arguments instead of passing None downstream.
        seg_tokenizer = self.tokenizer if self.tokenizer is not None else tokenizer
        seg_group_size = self.group_size if self.group_size is not None else group_size
        seg_channels = (
            self.audio_channels if self.audio_channels is not None else audio_channels
        )

        input_ids = [
            seg.to_input_id(seg_tokenizer, seg_group_size, seg_channels)
            for seg in tokenized_segments
        ]
        return torch.cat(input_ids, dim=1).type(torch.int64)  # [n_rvq + 1, seqlen]
src/mimo_audio/templates.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Xiaomi Corporation.
2
# Instruction prompts, grouped by task and language.

# Speech recognition (audio -> text), Chinese.
asr_zh_templates = [
    "请将这段语音转换为文字",
    "帮我识别这个音频文件中的内容",
    "把这段录音转成文本",
    "请转录这段语音",
    "将音频内容转换成文字格式",
    "识别并转写这段语音",
    "把语音内容写成文字",
    "转录这个音频片段",
    "将这段对话转换为文本",
    "麻烦帮我把这段录音整理成详细的文字记录",
]

# Speech recognition (audio -> text), English.
asr_en_templates = [
    "Please transcribe this audio file",
    "Convert this speech recording to text",
    "Transcribe the following voice message",
    "Turn this audio into readable text",
    "Please convert the recording to written format",
    "Transcribe what you hear in this audio",
    "Convert this spoken content to text",
    "Please write down what is said in this recording",
    "Transcribe this voice recording",
    "Could you please help me transcribe this important recording?",
    "Would you mind converting this voice message into a readable text format?",
    "I'd really appreciate it if you could turn this audio file into a written document",
]

# Speech synthesis (text -> audio), Chinese.
tts_zh_templates = [
    "请将这段文字转换为语音",
    "帮我把这个文本读出来",
    "将这些文字生成音频",
    "请朗读这段内容",
    "把这段话转换成语音文件",
    "生成这段文字的语音版本",
    "请用语音播报这些内容",
    "将文本转换为可听的音频",
    "帮我朗读这段文字",
    "把这些内容念出来",
]

# Speech synthesis (text -> audio), English.
tts_en_templates = [
    "Please convert this text to speech",
    "Turn this writing into audio",
    "Generate speech from this text",
    "Read this content out loud",
    "Convert these words to voice",
    "Create an audio version of this text",
    "Please vocalize this content",
    "Turn this text into audible format",
    "Help me convert this writing to speech",
    "Make this text into spoken audio",
]
src/mimo_audio_tokenizer/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ from .modeling_audio_tokenizer import MiMoAudioTokenizer, StreamingConfig, StreamingCache
3
+ from .configuration_audio_tokenizer import MiMoAudioTokenizerConfig
4
+
5
+
6
# Public API of the package: the tokenizer model, its streaming helper
# dataclasses and its configuration class.
__all__ = ['MiMoAudioTokenizer', 'StreamingConfig', 'StreamingCache', 'MiMoAudioTokenizerConfig']
src/mimo_audio_tokenizer/configuration_audio_tokenizer.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ from transformers import PretrainedConfig
3
+
4
+
5
class MiMoAudioTokenizerConfig(PretrainedConfig):
    """Configuration for the MiMo audio tokenizer (encoder, decoder, vocoder).

    Every constructor argument is stored verbatim as an attribute of the same
    name. The list-valued options default to ``None`` in the signature and are
    replaced by a fresh list per instance:

    * ``encoder_attn_window_size`` / ``decoder_attn_window_size`` -> ``[-1, -1]``
    * ``vocoder_attn_window_size`` -> ``[40, 10]``
    * ``codebook_size`` -> ``[1024]``

    Any extra keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "mimo_audio_tokenizer"

    def __init__(
        self,
        max_audio_seconds: int = 1800,
        stride_size: int = 2,
        avg_pooler: int = 1,
        d_model: int = 768,
        scale_embedding: bool = True,
        kernel_size: int = 3,
        activation_function: str = "gelu",
        encoder_layers: int = 8,
        encoder_skip_layer_id: int = None,
        encoder_attention_heads: int = 12,
        encoder_ffn_dim: int = 3072,
        encoder_causal: bool = False,
        encoder_attn_window_size: list[int] = None,
        decoder_layers: int = 8,
        decoder_attention_heads: int = 12,
        decoder_ffn_dim: int = 3072,
        decoder_kernel_size: int = 3,
        decoder_stride_size: int = 2,
        decoder_causal: bool = True,
        decoder_attn_window_size: list[int] = None,
        nfft: int = 1024,
        vocoder_dim: int = 512,
        vocoder_intermediate_dim: int = 4096,
        vocoder_num_layers: int = 30,
        n_mels: int = 80,
        sampling_rate: int = 24000,
        hop_length: int = 240,
        window_size: int = 1024,
        vocoder_padding: str = "same",
        fmin: int = 0,
        fmax: int = None,
        num_quantizers: int = 12,
        codebook_size: list[int] = None,
        threshold_ema_dead_code: int = 10,
        position_embedding_type: str = "rope",
        rope_theta: int = 10000,
        rope_type: str = "default",
        ln_type: str = "LayerNorm",
        vocoder_attention_heads: int = 4,
        vocoder_attn_window_size: list[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Swap the ``None`` sentinels for per-instance lists so that no
        # mutable default is shared between configurations.
        if encoder_attn_window_size is None:
            encoder_attn_window_size = [-1, -1]
        if decoder_attn_window_size is None:
            decoder_attn_window_size = [-1, -1]
        if codebook_size is None:
            codebook_size = [1024]
        if vocoder_attn_window_size is None:
            vocoder_attn_window_size = [40, 10]

        # Front-end / shared transformer settings.
        self.max_audio_seconds = max_audio_seconds
        self.stride_size = stride_size
        self.avg_pooler = avg_pooler
        self.d_model = d_model
        self.scale_embedding = scale_embedding
        self.kernel_size = kernel_size
        self.activation_function = activation_function

        # Encoder stack.
        self.encoder_layers = encoder_layers
        self.encoder_skip_layer_id = encoder_skip_layer_id
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_causal = encoder_causal
        self.encoder_attn_window_size = encoder_attn_window_size

        # Decoder stack.
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_kernel_size = decoder_kernel_size
        self.decoder_stride_size = decoder_stride_size
        self.decoder_causal = decoder_causal
        self.decoder_attn_window_size = decoder_attn_window_size

        # Vocoder and spectrogram settings.
        self.nfft = nfft
        self.vocoder_dim = vocoder_dim
        self.vocoder_intermediate_dim = vocoder_intermediate_dim
        self.vocoder_num_layers = vocoder_num_layers
        self.n_mels = n_mels
        self.sampling_rate = sampling_rate
        self.hop_length = hop_length
        self.window_size = window_size
        self.vocoder_padding = vocoder_padding
        self.fmin = fmin
        self.fmax = fmax

        # Residual vector quantizer.
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # Positional encoding and normalization.
        self.position_embedding_type = position_embedding_type
        self.rope_theta = rope_theta
        self.rope_type = rope_type
        self.ln_type = ln_type

        # Vocoder attention.
        self.vocoder_attention_heads = vocoder_attention_heads
        self.vocoder_attn_window_size = vocoder_attn_window_size
src/mimo_audio_tokenizer/modeling_audio_tokenizer.py ADDED
@@ -0,0 +1,857 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from flash_attn import flash_attn_varlen_func
8
+ from torch.nn import functional as F
9
+ from transformers.activations import ACT2FN
10
+ from transformers.modeling_utils import PreTrainedModel
11
+
12
+ from .configuration_audio_tokenizer import MiMoAudioTokenizerConfig
13
+ from .modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update, apply_rotary_pos_emb
14
+ from .quantization import ResidualVectorQuantizer
15
+ from dataclasses import dataclass, field
16
+ from typing import List
17
+
18
def get_sequence_mask(inputs, inputs_length):
    """Build a padding mask and a packed-to-padded gather index for a batch.

    Args:
        inputs: Tensor that fixes the mask geometry. If it is 3-D, its first
            two sizes are used as ``(bsz, tgt_len)``. Otherwise the batch size
            is ``len(inputs_length)`` and ``tgt_len`` is the longest length.
            Its device is where the outputs are created.
        inputs_length: 1-D integer tensor with the valid length of every
            sequence in the batch.

    Returns:
        tuple: ``(sequence_mask, unpacking_index)`` where

        * ``sequence_mask`` is a bool tensor of shape ``(bsz, tgt_len, 1)``,
          true on valid positions;
        * ``unpacking_index`` is an int64 tensor of shape ``(bsz * tgt_len,)``
          giving, for every padded slot, the row of the packed tensor to read
          (padding slots repeat the previous valid row and are expected to be
          masked out by the caller).
    """
    if inputs.dim() == 3:
        bsz, tgt_len, _ = inputs.size()
    else:
        bsz = inputs_length.shape[0]
        # Normalise to a Python int so it can be used directly as a size.
        tgt_len = int(torch.max(inputs_length))

    # Allocate the position range on the target device instead of creating it
    # on the CPU and copying it over on every call.
    positions = torch.arange(tgt_len, device=inputs.device)
    sequence_mask = torch.lt(positions, inputs_length.reshape(bsz, 1)).view(
        bsz, tgt_len, 1
    )
    # Running count of valid slots, minus one, is the packed row index.
    unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1
    return sequence_mask, unpacking_index
29
+
30
+
31
def unpack_hidden_states(
    hidden_states, lengths, sequence_mask=None, unpacking_index=None
):
    """Scatter packed rows back into a zero-padded ``(batch, time, dim)`` tensor.

    Args:
        hidden_states: Packed tensor whose rows along dim 0 are gathered.
        lengths: 1-D integer tensor with the valid length of each sequence.
        sequence_mask: Optional precomputed mask; recomputed together with
            ``unpacking_index`` when either of the two is missing.
        unpacking_index: Optional precomputed gather index.

    Returns:
        Tensor of shape ``(len(lengths), max(lengths), hidden_states.shape[-1])``
        with padding positions set to zero.
    """
    if sequence_mask is None or unpacking_index is None:
        sequence_mask, unpacking_index = get_sequence_mask(hidden_states, lengths)

    batch = lengths.shape[0]
    longest = torch.max(lengths)
    width = hidden_states.shape[-1]

    gathered = torch.index_select(hidden_states, 0, unpacking_index)
    padded = gathered.view(batch, longest, width)
    # Padding slots hold duplicated rows after the gather; blank them out.
    return torch.where(sequence_mask, padded, 0)
42
+
43
+
44
def get_position_ids(lengths):
    """Return per-sequence position indices for a packed batch.

    For ``lengths = [2, 3]`` the result is ``[0, 1, 0, 1, 2]``: positions
    restart at zero at the beginning of every sequence.

    Args:
        lengths: 1-D integer tensor with the length of each packed sequence.

    Returns:
        1-D tensor with ``lengths.sum()`` entries.
    """
    # Start offset of every sequence inside the packed axis (exclusive
    # prefix sum of the lengths).
    leading_zero = torch.zeros(1).to(lengths)
    seq_starts = torch.cat([leading_zero, lengths[:-1].cumsum(dim=0)])
    # Broadcast each start offset to all tokens of its sequence.
    token_starts = torch.repeat_interleave(seq_starts, lengths)
    flat_positions = torch.arange(0, lengths.sum()).to(token_starts)
    return flat_positions - token_starts
50
+
51
@dataclass
class StreamingConfig:
    """Tunable parameters for segment-wise (streaming) processing.

    NOTE(review): the unit of these values is not visible in this file; the
    ``N * 25`` defaults suggest a count of frames at 25 frames per second.
    Confirm against the code that consumes this config.
    """

    seg_point: int = 60 * 25
    process_seg_point: bool = True
    left_overlap: int = 10 * 25
    right_overlap: int = 40
    seg_point_left_overlap: int = 0
58
+
59
@dataclass
class StreamingCache:
    """Mutable state carried between streaming calls.

    Both fields start out as ``None`` and are meant to be filled in by the
    code that drives the streaming loop (outside this block).
    """

    # Cached hidden-state tensors.
    hidden_states: List[torch.Tensor] = None
    # Lengths already processed, one per cached entry (presumably; confirm
    # against the caller).
    processed_lengths: List[int] = None
63
+
64
class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".

    Raises:
        ValueError: If ``padding`` is neither "center" nor "same".
    """

    def __init__(
        self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
    ):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                            N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(
                spec,
                self.n_fft,
                self.hop_length,
                self.win_length,
                self.window,
                center=True,
            )
        if self.padding != "same":
            # Only reachable if ``self.padding`` was modified after construction.
            raise ValueError("Padding must be 'center' or 'same'.")

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        T = spec.shape[-1]
        pad = (self.win_length - self.hop_length) // 2

        # Inverse FFT
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap and Add
        output_size = (T - 1) * self.hop_length + self.win_length
        # Use an explicit end index: with ``pad == 0`` the slice ``[pad:-pad]``
        # is ``[0:0]`` and would silently drop the whole signal.
        end = output_size - pad
        y = torch.nn.functional.fold(
            ifft,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )[:, 0, 0, pad:end]

        # Window envelope
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )[0, 0, 0, pad:end]

        # Normalize (the envelope must be non-zero everywhere: NOLA constraint)
        assert (window_envelope > 1e-11).all()
        y = y / window_envelope

        return y
148
+
149
class ISTFTHead(nn.Module):
    """
    Output head that turns hidden features into a waveform via an inverse STFT.

    A linear layer predicts ``n_fft + 2`` values per frame, which are split
    into a log-magnitude half and a phase half; the resulting complex
    spectrogram is inverted by :class:`ISTFT`.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform (also used as the window length).
        hop_length (int): The distance between neighboring sliding window frames, which should align with
                          the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        # n_fft // 2 + 1 bins for the magnitude plus the same number for the phase.
        self.out = torch.nn.Linear(dim, n_fft + 2)
        self.istft = ISTFT(
            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict a complex spectrogram from ``x`` and synthesize audio from it.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        projected = self.out(x).transpose(1, 2)
        log_mag, phase = projected.chunk(2, dim=1)

        # Cap the magnitude as a safeguard against excessively large values.
        mag = torch.clip(torch.exp(log_mag), max=1e2)

        # cos/sin give the real and imaginary parts directly (this is where
        # phase wrapping happens), so there is no need to recompute the angle
        # with atan2 and exponentiate it.
        real = torch.cos(phase)
        imag = torch.sin(phase)

        # Build the spectrogram and run the ISTFT in float32, then go back to
        # the working dtype.
        spec = mag.float() * (real.float() + 1j * imag.float())
        return self.istft(spec).to(real.dtype)
199
+
200
+
201
class RotaryEmbedding(nn.Module):
    """Rotary position embedding table generator.

    The inverse frequencies and the attention scaling factor come from the
    initializer registered in ``ROPE_INIT_FUNCTIONS`` under ``rope_type``.
    ``forward`` is wrapped by ``dynamic_rope_update`` (defined elsewhere),
    which may refresh ``inv_freq`` before the call.

    Args:
        base: Rotary base (theta) handed to the initializer.
        dim: Per-head dimension handed to the initializer.
        max_seq_len: Stored as ``self.max_seq_len``; not used by the code in
            this class.
        rope_type: Key into ``ROPE_INIT_FUNCTIONS``.
        device: Device on which the inverse frequencies are created.
    """

    def __init__(self, base, dim, max_seq_len, rope_type="default", device=None):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.rope_type = rope_type

        init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        self.rope_init_fn = init_fn

        inv_freq, self.attention_scaling = init_fn(device=device, base=base, dim=dim)
        # Not part of the state dict: it is fully determined by the config.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` tables for the given 1-D ``position_ids``.

        ``x`` only supplies the target device and dtype. Each returned tensor
        has one row per position and ``2 * len(inv_freq)`` columns.
        """
        freq_column = self.inv_freq[:, None].float().expand(-1, 1).to(x.device)
        position_row = position_ids[None, :].float()

        # autocast has no "mps" backend here, so fall back to "cpu" for it.
        device_kind = x.device.type
        autocast_device = "cpu"
        if isinstance(device_kind, str) and device_kind != "mps":
            autocast_device = device_kind

        with torch.autocast(device_type=autocast_device, enabled=False):  # Force float32
            angles = (freq_column.float() @ position_row.float()).transpose(0, 1)
            table = torch.cat((angles, angles), dim=-1)
            cos = table.cos() * self.attention_scaling
            sin = table.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
235
+
236
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (equivalent to T5LayerNorm).

    Scales the input by the reciprocal RMS of its last dimension and by a
    learned per-channel weight; there is no mean subtraction and no bias.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size: Size of the normalized (last) dimension.
            eps: Value added to the mean square for numerical stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # The statistic is always accumulated in float32.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normed = hidden_states * inv_rms

        # Return to half precision only when the weight itself is half precision.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normed = normed.to(self.weight.dtype)

        return self.weight * normed
253
+
254
+
255
# Registry mapping a normalization type name (the ``ln_type`` value passed to
# the layers below) to the module class instantiated as ``LAYER_NORM[name](dim)``.
LAYER_NORM = {"LayerNorm": nn.LayerNorm, "RMSNorm": RMSNorm}
256
+
257
+
258
class Attention(nn.Module):
    """Multi-head self-attention over packed variable-length sequences.

    The sequences of a batch are concatenated along the first axis and
    attention is computed with ``flash_attn_varlen_func``, which keeps the
    sequences separate through cumulative sequence lengths.

    Args:
        embed_dim: Model width.
        num_heads: Number of attention heads; the head size is
            ``embed_dim // num_heads``.
        window_size: ``(left, right)`` local attention window handed to
            flash-attention unchanged.
        causal: Whether to apply a causal mask.
    """

    def __init__(self, embed_dim, num_heads, window_size=(-1, -1), causal=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.window_size = window_size

        # The key projection is the only one without a bias.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

        self.causal = causal

    def forward(
        self,
        hidden_states: torch.Tensor,
        seq_len: torch.Tensor,
        rope_position_embeddings=None,
    ):
        """Attend within each packed sequence.

        Args:
            hidden_states: 2-D packed tensor ``(total_tokens, embed_dim)``.
            seq_len: 1-D tensor with the length of every packed sequence.
            rope_position_embeddings: Optional ``(cos, sin)`` pair applied to
                queries and keys.

        Returns:
            Tensor of shape ``(total_tokens, embed_dim)``.
        """
        total_tokens, _ = hidden_states.size()

        def split_heads(projected):
            return projected.view(total_tokens, self.num_heads, self.head_dim)

        query_states = split_heads(self.q_proj(hidden_states))
        key_states = split_heads(self.k_proj(hidden_states))
        value_states = split_heads(self.v_proj(hidden_states))

        if rope_position_embeddings is not None:
            cos, sin = rope_position_embeddings
            query_states = apply_rotary_pos_emb(query_states, cos, sin)
            key_states = apply_rotary_pos_emb(key_states, cos, sin)

        # Sequence boundaries as [0, l0, l0 + l1, ...] in int32.
        cu_seqlens = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(
            torch.int32
        )
        longest = torch.max(seq_len).to(torch.int32).detach()

        context = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens,
            cu_seqlens,
            longest,
            longest,
            causal=self.causal,
            window_size=self.window_size,
        )
        merged = context.reshape(total_tokens, self.embed_dim)
        return self.out_proj(merged)
312
+
313
+
314
class TransformerLayer(nn.Module):
    """Pre-norm transformer block: self-attention followed by a feed-forward net.

    Args:
        act: Activation callable used between the two feed-forward layers.
        d_model: Model width.
        encoder_attention_heads: Number of attention heads.
        encoder_ffn_dim: Hidden width of the feed-forward network.
        causal: Whether self-attention is causal.
        ln_type: Key into ``LAYER_NORM`` selecting the normalization class.
        attn_window_size: ``(left, right)`` attention window.
    """

    def __init__(
        self,
        act,
        d_model,
        encoder_attention_heads,
        encoder_ffn_dim,
        causal,
        ln_type="LayerNorm",
        attn_window_size=(-1, -1),
    ):
        super().__init__()
        norm_cls = LAYER_NORM[ln_type]

        self.embed_dim = d_model
        self.self_attn = Attention(
            self.embed_dim, encoder_attention_heads, attn_window_size, causal
        )
        self.self_attn_layer_norm = norm_cls(self.embed_dim)

        self.activation_fn = act
        self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
        self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)

        self.final_layer_norm = norm_cls(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        seq_len: torch.Tensor,
        rope_position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the block to packed ``hidden_states`` and return the result."""
        # Attention sub-layer with residual connection.
        attn_out = self.self_attn(
            self.self_attn_layer_norm(hidden_states),
            seq_len,
            rope_position_embeddings=rope_position_embeddings,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-layer with residual connection.
        ffn_hidden = self.activation_fn(self.fc1(self.final_layer_norm(hidden_states)))
        hidden_states = hidden_states + self.fc2(ffn_hidden)

        # In half precision, clamp to the finite range when an overflow
        # (inf/NaN) is detected in the output.
        is_half = hidden_states.dtype in (torch.float16, torch.bfloat16)
        if is_half and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            limit = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-limit, max=limit)
        return hidden_states
366
+
367
+
368
class TransformerVocos(nn.Module):
    """Transformer vocoder: mel frames -> transformer stack -> ISTFT waveform."""

    def __init__(self, config: MiMoAudioTokenizerConfig):
        super().__init__()
        self.config = config
        self.max_source_positions = (
            config.max_audio_seconds * config.sampling_rate // config.hop_length
        )
        self.embeddings = nn.Linear(config.n_mels, config.vocoder_dim, bias=False)

        # NOTE: the attribute name (including its spelling) is kept as-is,
        # since it is part of the module's attribute interface.
        self.poisition_embedding = RotaryEmbedding(
            config.rope_theta,
            config.vocoder_dim // config.vocoder_attention_heads,
            self.max_source_positions,
            config.rope_type,
        )

        blocks = []
        for _ in range(config.vocoder_num_layers):
            blocks.append(
                TransformerLayer(
                    ACT2FN[config.activation_function],
                    config.vocoder_dim,
                    config.vocoder_attention_heads,
                    config.vocoder_intermediate_dim,
                    causal=False,
                    ln_type=config.ln_type,
                    attn_window_size=config.vocoder_attn_window_size,
                )
            )
        self.layers = nn.ModuleList(blocks)

        self.layer_norm = LAYER_NORM[config.ln_type](config.vocoder_dim)
        self.hop_size = config.hop_length
        self.head = ISTFTHead(
            config.vocoder_dim,
            config.nfft,
            config.hop_length,
            config.vocoder_padding,
        )

    def forward(self, x: torch.Tensor, input_length):
        """Synthesize audio from mel features.

        Args:
            x: Padded mel tensor; it is transposed on entry so that the mel
                axis (``n_mels``) becomes the last dimension.
            input_length: 1-D tensor with the number of valid frames per item.

        Returns:
            tuple: waveform with an added channel axis ``(B, 1, samples)`` and
            the per-item sample count ``input_length * hop_length``.
        """
        frames = x.transpose(1, 2)
        attention_mask, unpacking_index = get_sequence_mask(frames, input_length)

        # Drop the padding: keep only valid frames, packed along dim 0.
        packed = torch.masked_select(frames, attention_mask).view(
            torch.sum(input_length), self.config.n_mels
        )
        packed = self.embeddings(packed)

        position_ids = torch.arange(
            0, packed.size(0), device=packed.device, dtype=torch.long
        )
        rope = self.poisition_embedding(packed, position_ids)
        for layer in self.layers:
            packed = layer(packed, input_length, rope_position_embeddings=rope)

        packed = self.layer_norm(packed)
        padded = unpack_hidden_states(
            packed, input_length, attention_mask, unpacking_index
        )
        waveform = self.head(padded)
        output_length = input_length * self.hop_size
        return waveform[:, None, :], output_length
429
+
430
+
431
class AudioEncoder(nn.Module):
    """Mel-spectrogram encoder with optional pooling and residual vector quantization.

    Pipeline: two 1-D convolutions (the second one strided) -> packed
    transformer stack with rotary position embeddings -> layer norm ->
    optional strided-conv down-sampling -> optional ``ResidualVectorQuantizer``.
    """

    def __init__(self, config: MiMoAudioTokenizerConfig):
        super().__init__()
        # NOTE: this mutates the config object passed in by the caller.
        config._attn_implementation = "flash_attention_2"
        self.config = config
        # Maximum number of mel frames, divided by the conv stride.
        self.max_source_positions = (
            config.max_audio_seconds * config.sampling_rate // config.hop_length
        ) // config.stride_size
        # Stored here; not referenced by the methods of this class.
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        # 1-based index of the layer whose output is added back as a skip
        # connection (None disables it); see get_features.
        self.skip_layer_idx = config.encoder_skip_layer_id
        self.conv1 = nn.Conv1d(
            config.n_mels, config.d_model, kernel_size=config.kernel_size, padding=1
        )
        self.conv2 = nn.Conv1d(
            config.d_model,
            config.d_model,
            kernel_size=config.kernel_size,
            stride=config.stride_size,
            padding=1,
        )

        self.position_embedding = RotaryEmbedding(
            config.rope_theta,
            config.d_model // config.encoder_attention_heads,
            self.max_source_positions,
            config.rope_type,
        )

        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    ACT2FN[config.activation_function],
                    config.d_model,
                    config.encoder_attention_heads,
                    config.encoder_ffn_dim,
                    causal=self.config.encoder_causal,
                    ln_type=self.config.ln_type,
                    attn_window_size=self.config.encoder_attn_window_size,
                )
                for _ in range(config.encoder_layers)
            ]
        )

        self.layer_norm = LAYER_NORM[config.ln_type](config.d_model)

        # Optional temporal down-sampling by a factor of ``avg_pooler``.
        if self.config.avg_pooler != 1:
            self.down_sample_layer = nn.Sequential(
                nn.Conv1d(
                    config.d_model,
                    config.d_model,
                    config.avg_pooler,
                    config.avg_pooler,
                    bias=False,
                ),
                nn.GELU(),
            )
            self.down_sample_norm = LAYER_NORM[config.ln_type](config.d_model)
        else:
            self.down_sample_layer = None

        # Optional quantizer; ``num_quantizers == 0`` disables it.
        if self.config.num_quantizers != 0:
            self.quantizer = ResidualVectorQuantizer(
                dimension=self.config.d_model,
                n_q=self.config.num_quantizers,
                bins=self.config.codebook_size,
                threshold_ema_dead_code=self.config.threshold_ema_dead_code,
            )
        else:
            self.quantizer = None

    def get_features(self, input_features, output_length):
        """Run the convolutional front-end and the transformer stack.

        Args:
            input_features: Padded mel tensor laid out for ``nn.Conv1d``
                (batch, n_mels, frames).
            output_length: 1-D tensor with the number of valid frames per item
                after the two convolutions.

        Returns:
            tuple: ``(hidden_states, output_length, attention_mask,
            unpacking_index, tgt_len, bsz)`` where ``hidden_states`` is packed
            as ``(sum(output_length), d_model)`` and the remaining values are
            what is needed to unpack it to ``(bsz, tgt_len, d_model)``. When
            down-sampling is enabled they all refer to the pooled resolution.
        """
        # Match the dtype/device of the conv weights.
        input_features = input_features.to(self.conv1.weight)
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        bsz, tgt_len, _ = inputs_embeds.size()

        hidden_states = inputs_embeds

        # Positions restart at zero for every sequence of the packed batch.
        position_ids = (
            get_position_ids(output_length).long().to(input_features.device)
        )
        # ``input_features`` only supplies the device/dtype of the tables.
        rope_position_embeddings = self.position_embedding(
            input_features, position_ids
        )

        attention_mask, unpacking_index = get_sequence_mask(
            hidden_states, output_length
        )

        # Pack: keep only valid frames, concatenated along dim 0.
        hidden_states = torch.masked_select(hidden_states, attention_mask).view(
            torch.sum(output_length), self.config.d_model
        )

        # Stays 0.0 (a no-op when added) unless a skip layer is configured.
        skip_connect_hidden_states = 0.0
        for idx, encoder_layer in enumerate(self.layers):
            hidden_states = encoder_layer(
                hidden_states,
                output_length,
                rope_position_embeddings=rope_position_embeddings,
            )
            if (self.skip_layer_idx is not None) and idx == self.skip_layer_idx - 1:
                skip_connect_hidden_states = hidden_states.clone()

        hidden_states += skip_connect_hidden_states
        hidden_states = self.layer_norm(hidden_states)

        if self.down_sample_layer is not None:
            # Unpack to (bsz, tgt_len, d_model) so the strided conv can run.
            hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
                bsz, tgt_len, self.config.d_model
            )
            # Right-pad the time axis to a multiple of the pooling factor.
            if hidden_states.size(1) % self.config.avg_pooler:
                pad_len = (
                    self.config.avg_pooler
                    - hidden_states.size(1) % self.config.avg_pooler
                )
                hidden_states = torch.nn.functional.pad(
                    hidden_states, (0, 0, 0, pad_len), mode="constant", value=0.0
                )
                tgt_len += pad_len
            tgt_len = tgt_len // self.config.avg_pooler
            hidden_states = self.down_sample_layer(hidden_states.transpose(1, 2))
            # Ceil-divide the valid lengths by the pooling factor.
            output_length = (
                output_length // self.config.avg_pooler
                + (output_length % self.config.avg_pooler != 0).int()
            )
            hidden_states = hidden_states.transpose(1, 2)
            # Re-pack at the pooled resolution.
            attention_mask, unpacking_index = get_sequence_mask(
                hidden_states, output_length
            )
            hidden_states = torch.masked_select(hidden_states, attention_mask).view(
                torch.sum(output_length), self.config.d_model
            )
            hidden_states = self.down_sample_norm(hidden_states)

        return (
            hidden_states,
            output_length,
            attention_mask,
            unpacking_index,
            tgt_len,
            bsz,
        )

    def get_output_length(self, mel_len):
        """Return the number of frames produced by ``conv1`` + ``conv2`` for ``mel_len`` input frames."""
        # conv1: stride 1, padding 1 -> L + 2 - kernel_size + 1.
        tgt_len = mel_len + 3 - self.config.kernel_size
        # conv2: padding 1, stride ``stride_size``.
        return (tgt_len + 2 - self.config.kernel_size) // self.config.stride_size + 1

    @torch.no_grad()
    def encode(
        self,
        input_features,
        input_lens=None,
        output_length=None,
        return_codes_only=False,
        n_q=None,
        use_quantizer=True,
    ):
        """Encode packed mel features, optionally passing them through the quantizer.

        Args:
            input_features: Packed mel frames; unpacked with ``input_lens``
                before the convolutions.
            input_lens: 1-D tensor with the number of mel frames per item.
            output_length: Optional precomputed encoder lengths; derived from
                ``input_lens`` when omitted.
            return_codes_only: If true (and the quantizer is used), return
                ``(codes, output_length)`` right after quantization.
            n_q: Forwarded to ``self.quantizer.encode``.
            use_quantizer: If false, skip quantization even when a quantizer
                exists.

        Returns:
            ``(hidden_states, hidden_states_packed, output_length, codes)``
            where ``hidden_states`` is zero-padded ``(bsz, tgt_len, d_model)``,
            ``hidden_states_packed`` is the packed copy, and ``codes`` is
            ``None`` when no quantization was applied. See
            ``return_codes_only`` for the alternative return value.
        """
        if output_length is None:
            output_length = self.get_output_length(input_lens)
        input_features = unpack_hidden_states(input_features, input_lens)
        hidden_states, output_length, attention_mask, unpacking_index, tgt_len, bsz = (
            self.get_features(
                input_features=input_features.transpose(1, 2),
                output_length=output_length,
            )
        )

        dtype = hidden_states.dtype

        if use_quantizer and self.quantizer is not None:
            # NOTE: casts the quantizer module to float32 in place; the
            # features are quantized in float32 and cast back afterwards.
            self.quantizer.float()

            codes = self.quantizer.encode(hidden_states.float(), n_q=n_q)
            if return_codes_only:
                return codes, output_length
            hidden_states = self.quantizer.decode(codes)
            hidden_states = hidden_states.to(dtype)
        else:
            codes = None

        hidden_states_packed = hidden_states.clone()

        # Unpack to (bsz, tgt_len, d_model) and zero the padding positions.
        hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
            bsz, tgt_len, self.config.d_model
        )
        hidden_states = torch.where(attention_mask, hidden_states, 0)
        return hidden_states, hidden_states_packed, output_length, codes

    @torch.no_grad()
    def decode_vq(self, codes):
        """Reconstruct hidden states from quantizer ``codes`` (quantizer run in float32)."""
        self.quantizer.float()
        hidden_states = self.quantizer.decode(codes)

        return hidden_states
628
+
629
+
630
class CausalConvTranspose1d(nn.Module):
    """Transposed 1-D convolution + GroupNorm that trims the right-hand overhang.

    Accepts either packed ``(total, C)`` or padded ``(N, L, C)`` input and can
    return either layout.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
        self.norm = nn.GroupNorm(1, out_channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, hidden_states, input_length, output_dim=None):
        """Up-sample ``hidden_states`` along time.

        Args:
            hidden_states: Packed 2-D or padded 3-D input.
            input_length: 1-D tensor with the valid length of each sequence.
            output_dim: Rank of the returned tensor (``<= 2`` packed,
                otherwise padded); defaults to the rank of the input.

        Returns:
            tuple: the up-sampled tensor and the per-sequence output lengths.
        """
        kernel = self.conv.kernel_size[0]
        step = self.conv.stride[0]
        batch = input_length.shape[0]

        if output_dim is None:
            output_dim = hidden_states.dim()

        if hidden_states.dim() <= 2:
            # Packed input: scatter into a zero-padded (N, L, C) tensor first.
            in_mask, gather_index = get_sequence_mask(hidden_states, input_length)
            hidden_states = torch.index_select(hidden_states, 0, gather_index).view(
                batch, torch.max(input_length), self.in_channels
            )
            hidden_states = torch.where(in_mask, hidden_states, 0)

        # (N, L, C) -> (N, C, L) for the convolution, then back again.
        upsampled = self.norm(self.conv(hidden_states.transpose(2, 1)))
        upsampled = upsampled.transpose(2, 1)

        # Drop the frames that overhang on the right of the last input frame.
        overhang = max(0, kernel - step)
        upsampled = upsampled[:, : upsampled.shape[1] - overhang, :]
        output_length = (input_length - 1) * step + kernel - overhang

        out_mask, _ = get_sequence_mask(upsampled, output_length)
        if output_dim <= 2:
            result = torch.masked_select(upsampled, out_mask).view(
                -1, self.out_channels
            )
            return result, output_length

        result = torch.where(out_mask, upsampled, 0)
        result = result[:, : torch.max(output_length), :]
        return result, output_length
673
+
674
+
675
class AudioDecoder(nn.Module):
    """Decode audio embeddings to a waveform.

    Pipeline: optional up-sampling (``dconv1``) -> packed transformer stack
    with rotary position embeddings -> layer norm -> transposed conv to mel
    channels (``dconv2``) -> :class:`TransformerVocos` vocoder.
    """

    def __init__(self, config: MiMoAudioTokenizerConfig):
        super().__init__()
        self.config = config
        self.max_source_positions = (
            config.max_audio_seconds * config.sampling_rate // config.hop_length
        )

        # Mirrors the encoder's pooling stage; absent when avg_pooler == 1.
        self.dconv1 = (
            CausalConvTranspose1d(
                config.d_model,
                config.d_model,
                config.avg_pooler,
                config.avg_pooler,
            )
            if config.avg_pooler != 1
            else None
        )

        self.position_embedding = RotaryEmbedding(
            config.rope_theta,
            config.d_model // config.decoder_attention_heads,
            self.max_source_positions,
            config.rope_type,
        )

        blocks = []
        for _ in range(config.decoder_layers):
            blocks.append(
                TransformerLayer(
                    ACT2FN[config.activation_function],
                    config.d_model,
                    config.decoder_attention_heads,
                    config.decoder_ffn_dim,
                    causal=config.decoder_causal,
                    ln_type=config.ln_type,
                    attn_window_size=config.decoder_attn_window_size,
                )
            )
        self.layers = nn.ModuleList(blocks)

        self.layer_norm = LAYER_NORM[config.ln_type](config.d_model)
        self.dconv2 = CausalConvTranspose1d(
            config.d_model,
            config.n_mels,
            config.decoder_kernel_size,
            config.decoder_stride_size,
        )
        self.vocoder = TransformerVocos(config)

    def forward(
        self,
        audio_embed,
        input_length,
    ):
        """Reconstruct a waveform from ``audio_embed``.

        Args:
            audio_embed: Embeddings whose last dimension is ``d_model``.
            input_length: 1-D tensor with the valid length of each sequence.

        Returns:
            The waveform tensor produced by the vocoder.
        """
        assert audio_embed.shape[-1] == self.config.d_model
        # Match the dtype/device of the decoder weights.
        hidden_states = audio_embed.to(self.layer_norm.weight)

        if self.dconv1 is None:
            output_length = input_length
        else:
            hidden_states, output_length = self.dconv1(
                hidden_states, input_length, output_dim=3
            )

        position_ids = (
            get_position_ids(output_length).long().to(hidden_states.device)
        )
        rope = self.position_embedding(hidden_states, position_ids)

        # Pack: keep only valid frames, concatenated along dim 0.
        valid_mask, _ = get_sequence_mask(hidden_states, output_length)
        hidden_states = torch.masked_select(hidden_states, valid_mask).view(
            torch.sum(output_length), self.config.d_model
        )

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                output_length,
                rope_position_embeddings=rope,
            )

        hidden_states = self.layer_norm(hidden_states)

        # Up-sample to mel resolution, returned as a padded 3-D tensor.
        coarse_mel, mel_length = self.dconv2(
            hidden_states, output_length, output_dim=3
        )

        recon_wav, _ = self.vocoder(
            x=coarse_mel.transpose(1, 2),
            input_length=mel_length,
        )
        return recon_wav
777
+
778
+
779
class MiMoAudioTokenizer(PreTrainedModel):
    """Audio tokenizer that pairs an ``AudioEncoder`` with an ``AudioDecoder``.

    ``encode`` delegates to the encoder, ``decode`` turns codes back into
    audio in one shot, and ``streaming_decode`` does the same chunk by chunk
    while carrying state in a ``StreamingCache``.
    """

    config_class = MiMoAudioTokenizerConfig

    def __init__(self, config: MiMoAudioTokenizerConfig):
        super().__init__(config)
        self.config = config
        self.sampling_rate = config.sampling_rate
        self.encoder = AudioEncoder(config=config)
        self.decoder = AudioDecoder(config=config)
        self.downsample_rate = int(self.config.hop_length * 2 * self.config.avg_pooler)

    def get_output_length(self, mel_len):
        """Map a mel length to the length after two kernel-based reductions.

        The first step subtracts ``kernel_size - 3``; the second applies the
        same reduction with padding 2 and divides by ``stride_size``.
        NOTE(review): presumably mirrors the encoder's convolution stack --
        confirm against ``AudioEncoder``.
        """
        tgt_len = mel_len + 3 - self.config.kernel_size
        return (tgt_len + 2 - self.config.kernel_size) // self.config.stride_size + 1

    @torch.no_grad()
    def encode(self, mels, input_lens, use_quantizer=True):
        """Run the encoder on ``mels``.

        Returns:
            The 4-tuple ``(hidden_states, hidden_states_packed,
            encoder_output_length, codes)`` produced by
            ``self.encoder.encode``.
        """
        # The output length reported by the encoder is authoritative, so the
        # previously pre-computed (and immediately overwritten) value is gone.
        hidden_states, hidden_states_packed, encoder_output_length, codes = (
            self.encoder.encode(
                mels, input_lens=input_lens, use_quantizer=use_quantizer
            )
        )
        return hidden_states, hidden_states_packed, encoder_output_length, codes

    @torch.no_grad()
    def decode(self, codes):
        """Decode ``codes`` to audio, treating them as a single sequence."""
        hidden_states = self.encoder.decode_vq(codes)
        output = self.decoder(
            hidden_states,
            torch.tensor([hidden_states.size(0)], device=hidden_states.device),
        )
        return output

    @torch.no_grad()
    def streaming_decode(
        self,
        codes_chunks,
        chunk_input_lengths,
        history_cache=None,
        streaming_config=None,
        last_chunk=False,
    ):
        """Decode one chunk of codes for every stream in the batch.

        Args:
            codes_chunks: Codes of the current chunk for all streams, passed
                to ``self.encoder.decode_vq``.
            chunk_input_lengths: For each stream, the number of new
                hidden-state frames contained in this chunk.
            history_cache: ``StreamingCache`` returned by the previous call.
                A fresh cache is created when omitted, so independent
                decoding sessions never share state.
            streaming_config: Supplies ``left_overlap`` and ``right_overlap``.
                A default ``StreamingConfig`` is created when omitted.
            last_chunk: When true, all remaining audio is emitted and no
                right overlap is held back.

        Returns:
            ``(return_wavs, history_cache)``. ``return_wavs`` has one entry
            per stream; an entry is ``None`` when the stream does not yet
            hold more than ``right_overlap`` frames.
        """
        # Build the defaults per call: a default instance in the signature
        # would be created once and then mutated by every call that relies
        # on it.
        if history_cache is None:
            history_cache = StreamingCache()
        if streaming_config is None:
            streaming_config = StreamingConfig()

        hidden_states = self.encoder.decode_vq(codes_chunks)
        input_lengths = []
        input_hidden_states = []
        start_idx = 0
        cache_hidden_states = []
        for i, input_length in enumerate(chunk_input_lengths):
            sample_hidden_states = hidden_states[start_idx:start_idx + input_length]
            start_idx += input_length
            if history_cache.hidden_states is not None:
                # Prepend the frames kept from the previous call.
                sample_hidden_states = torch.cat(
                    [history_cache.hidden_states[i], sample_hidden_states], dim=0
                )
                # Rebind instead of `+=` so a tensor-valued length owned by
                # the caller is never modified in place.
                input_length = input_length + history_cache.hidden_states[i].size(0)
            input_hidden_states.append(sample_hidden_states)
            cache_hidden_states.append(sample_hidden_states.clone())
            input_lengths.append(input_length)
        input_hidden_states = torch.cat(input_hidden_states, dim=0)
        input_lengths = torch.tensor(input_lengths, device=hidden_states.device)
        output = self.decoder(input_hidden_states, input_lengths)

        return_wavs = []
        frames_per_token = (
            self.config.avg_pooler * self.config.stride_size * self.config.hop_length
        )
        processed_lengths = []
        for i, wav in enumerate(output):
            wav = wav.float().detach().cpu()
            total_length = input_lengths[i].item()
            # Frames whose audio was already returned by an earlier call.
            if history_cache.processed_lengths is not None:
                start_idx = history_cache.processed_lengths[i]
            else:
                start_idx = 0
            if last_chunk:
                return_wavs.append(wav[:, start_idx * frames_per_token:])
                new_processed_length = total_length
            elif total_length <= streaming_config.right_overlap:
                # Not enough context yet: emit nothing and keep everything.
                return_wavs.append(None)
                new_processed_length = 0
            else:
                end_idx = total_length - streaming_config.right_overlap
                wav = wav[:, start_idx * frames_per_token: end_idx * frames_per_token]
                return_wavs.append(wav)
                new_processed_length = end_idx
                if total_length > streaming_config.left_overlap:
                    # Keep only the trailing `left_overlap` frames and shift
                    # the processed counter by the number of dropped frames.
                    cache_hidden_states[i] = cache_hidden_states[i][
                        -streaming_config.left_overlap:
                    ]
                    new_processed_length -= total_length - streaming_config.left_overlap
            processed_lengths.append(new_processed_length)
        history_cache.hidden_states = cache_hidden_states
        history_cache.processed_lengths = processed_lengths

        return return_wavs, history_cache
src/mimo_audio_tokenizer/modeling_rope_utils.py ADDED
@@ -0,0 +1,878 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from functools import wraps
18
+ from typing import Optional
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import is_torch_available, logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ if is_torch_available():
28
+ import torch
29
+
30
+
31
def dynamic_rope_update(rope_forward):
    """
    Decorate a RoPE ``forward`` so that dynamic RoPE variants refresh their
    inverse frequencies before the wrapped forward pass runs.

    Args:
        rope_forward (Callable):
            The forward pass of the RoPE implementation.

    Returns:
        The decorated forward pass.
    """

    def _refresh_longrope(self, position_ids, device):
        """Use the long factors beyond the original pretraining length, the original ones otherwise."""
        seq_len = torch.max(position_ids) + 1
        if hasattr(self.config, "original_max_position_embeddings"):
            pretrained_len = self.config.original_max_position_embeddings
        else:
            pretrained_len = self.config.max_position_embeddings

        if not seq_len > pretrained_len:
            # The buffer follows the module across devices automatically, the
            # separately stored original copy does not, hence the `.to()`.
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            return

        if not hasattr(self, "long_inv_freq"):
            # Computed lazily once and then reused.
            self.long_inv_freq, _ = self.rope_init_fn(
                self.config, device, seq_len=pretrained_len + 1
            )
        self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)

    def _refresh_dynamic(self, position_ids, device):
        """
        Recompute `inv_freq` when
        1 - the sequence grows beyond the cached length (allow scaling), or
        2 - the sequence is back within the original scale (avoid losing
            precision with small sequences).
        """
        seq_len = torch.max(position_ids) + 1

        grew = seq_len > self.max_seq_len_cached
        if grew:
            new_inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len
            )
            # TODO joao: may break with compilation
            self.register_buffer("inv_freq", new_inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len

        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):
            # Reset. The `.to()` is needed because the stored original copy is
            # not moved together with the module.
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @wraps(rope_forward)
    def wrapper(self, x, position_ids):
        kind = self.rope_type
        if "dynamic" in kind:
            _refresh_dynamic(self, position_ids, device=x.device)
        elif kind == "longrope":
            _refresh_longrope(self, position_ids, device=x.device)
        return rope_forward(self, x, position_ids)

    return wrapper
100
+
101
+
102
def _compute_default_rope_parameters(
    config: Optional[PretrainedConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> tuple["torch.Tensor", float]:
    """
    Compute the inverse frequencies of the original (unscaled) RoPE.

    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration. Provides `rope_theta` and the head dimension.
        device (`torch.device`):
            The device on which the inverse frequencies are created.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
        rope_kwargs (`Dict`, *optional*):
            Legacy alternative to `config` (keys `base` and `dim`), kept for backward compatibility.
            Mutually exclusive with `config`.
    Returns:
        Tuple of (`torch.Tensor`, `float`): the inverse frequencies and the post-processing scaling factor
        applied to the computed cos/sin (always `1.0` here).
    """
    if rope_kwargs and config is not None:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )

    if rope_kwargs:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
    elif config is not None:
        base = config.rope_theta
        rotary_fraction = getattr(config, "partial_rotary_factor", 1.0)
        # A missing or falsy `head_dim` falls back to hidden_size / num_heads.
        head_dim = (
            getattr(config, "head_dim", None)
            or config.hidden_size // config.num_attention_heads
        )
        dim = int(head_dim * rotary_fraction)

    # inv_freq[i] = 1 / base^(2i / dim)
    exponents = (
        torch.arange(0, dim, 2, dtype=torch.int64).to(
            device=device, dtype=torch.float
        )
        / dim
    )
    inv_freq = 1.0 / (base**exponents)

    # The attention factor is unused in this type of RoPE.
    return inv_freq, 1.0
157
+
158
+
159
def _compute_linear_scaling_rope_parameters(
    config: Optional[PretrainedConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> tuple["torch.Tensor", float]:
    """
    Compute the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev

    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration. `config.rope_scaling["factor"]` is the scaling factor.
        device (`torch.device`):
            The device on which the inverse frequencies are created.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
        rope_kwargs (`Dict`, *optional*):
            Legacy alternative to `config`, kept for backward compatibility. Mutually exclusive with `config`.
    Returns:
        Tuple of (`torch.Tensor`, `float`): the inverse frequencies and the post-processing scaling factor
        applied to the computed cos/sin (unused in this type of RoPE).
    """
    if rope_kwargs and config is not None:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )

    if rope_kwargs:
        factor = rope_kwargs["factor"]
    elif config is not None:
        factor = config.rope_scaling["factor"]

    # Start from the default RoPE parameters ...
    inv_freq, attention_factor = _compute_default_rope_parameters(
        config, device, seq_len, **rope_kwargs
    )

    # ... and scale them. Originally the scaling was applied to the position ids, but since
    # `embs = inv_freq @ position_ids`, dividing the inverse frequencies is equivalent.
    inv_freq /= factor
    return inv_freq, attention_factor
200
+
201
+
202
def _compute_dynamic_ntk_parameters(
    config: Optional[PretrainedConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> tuple["torch.Tensor", float]:
    """
    Compute the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla

    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device on which the inverse frequencies are created.
        seq_len (`int`, *optional*):
            The current sequence length, used to update the dynamic RoPE at inference time. Values that are
            missing or not larger than `max_position_embeddings` are replaced by `max_position_embeddings`.
        rope_kwargs (`Dict`, *optional*):
            Legacy alternative to `config` (keys `base`, `dim`, `max_position_embeddings`, `factor`), kept for
            backward compatibility. Mutually exclusive with `config`.
    Returns:
        Tuple of (`torch.Tensor`, `float`): the inverse frequencies and the post-processing scaling factor
        applied to the computed cos/sin (unused in this type of RoPE).
    """
    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )
    if len(rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
        max_position_embeddings = rope_kwargs["max_position_embeddings"]
        factor = rope_kwargs["factor"]
    elif config is not None:
        base = config.rope_theta
        partial_rotary_factor = (
            config.partial_rotary_factor
            if hasattr(config, "partial_rotary_factor")
            else 1.0
        )
        # Same resolution as `_compute_default_rope_parameters`: an explicit
        # `head_dim=None` falls back to hidden_size / num_heads, and the
        # fallback is only evaluated when it is actually needed.
        head_dim = (
            getattr(config, "head_dim", None)
            or config.hidden_size // config.num_attention_heads
        )
        dim = int(head_dim * partial_rotary_factor)
        max_position_embeddings = config.max_position_embeddings
        factor = config.rope_scaling["factor"]

    attention_factor = 1.0  # Unused in this type of RoPE

    # seq_len: default to max_position_embeddings, e.g. at init time
    seq_len = (
        seq_len
        if seq_len is not None and seq_len > max_position_embeddings
        else max_position_embeddings
    )

    # Compute the inverse frequencies from the rescaled base
    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (
        dim / (dim - 2)
    )
    inv_freq = 1.0 / (
        base
        ** (
            torch.arange(0, dim, 2, dtype=torch.int64).to(
                device=device, dtype=torch.float
            )
            / dim
        )
    )
    return inv_freq, attention_factor
271
+
272
+
273
def _compute_yarn_parameters(
    config: PretrainedConfig,
    device: "torch.device",
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> tuple["torch.Tensor", float]:
    """
    Compute the inverse frequencies with NTK scaling (YaRN). Please refer to the
    [original paper](https://huggingface.co/papers/2309.00071)

    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device on which the inverse frequencies are created.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
        rope_kwargs (`Dict`, *optional*):
            Must be empty; a `ValueError` is raised otherwise.
    Returns:
        Tuple of (`torch.Tensor`, `float`): the inverse frequencies and the post-processing scaling factor
        applied to the computed cos/sin.
    """
    # No need to keep BC with yarn, unreleased when this new pattern was created.
    if len(rope_kwargs) > 0:
        raise ValueError(
            f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
        )

    base = config.rope_theta
    partial_rotary_factor = (
        config.partial_rotary_factor
        if hasattr(config, "partial_rotary_factor")
        else 1.0
    )
    # Same resolution as `_compute_default_rope_parameters`: an explicit
    # `head_dim=None` falls back to hidden_size / num_heads, and the fallback
    # is only evaluated when it is actually needed.
    head_dim = (
        getattr(config, "head_dim", None)
        or config.hidden_size // config.num_attention_heads
    )
    dim = int(head_dim * partial_rotary_factor)
    factor = config.rope_scaling["factor"]
    attention_factor = config.rope_scaling.get("attention_factor")
    mscale = config.rope_scaling.get("mscale")
    mscale_all_dim = config.rope_scaling.get("mscale_all_dim")

    # NOTE: DeepSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
    # values to compute the default attention scaling factor, instead of using `factor`.
    if "original_max_position_embeddings" in config.rope_scaling:
        original_max_position_embeddings = config.rope_scaling[
            "original_max_position_embeddings"
        ]
        factor = config.max_position_embeddings / original_max_position_embeddings
    else:
        original_max_position_embeddings = config.max_position_embeddings

    def get_mscale(scale, mscale=1):
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    # Sets the attention factor as suggested in the paper
    if attention_factor is None:
        if mscale and mscale_all_dim:
            attention_factor = float(
                get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)
            )
        else:
            attention_factor = get_mscale(factor)

    # Optional config options
    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
    beta_fast = config.rope_scaling.get("beta_fast") or 32
    beta_slow = config.rope_scaling.get("beta_slow") or 1

    # Compute the inverse frequencies
    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
        """Inverse dimension formula to find the dimension based on the number of rotations"""
        return (
            dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
        ) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
        """Find dimension range bounds based on rotations"""
        low = math.floor(
            find_correction_dim(low_rot, dim, base, max_position_embeddings)
        )
        high = math.ceil(
            find_correction_dim(high_rot, dim, base, max_position_embeddings)
        )
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(lower, upper, dim):
        """Ramp rising linearly from 0 at `lower` to 1 at `upper`, clamped to [0, 1]."""
        if lower == upper:
            upper += 0.001  # Prevent singularity

        linear_func = (torch.arange(dim, dtype=torch.float32) - lower) / (
            upper - lower
        )
        return torch.clamp(linear_func, 0, 1)

    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
    # to expand the possible context length. In other words, interpolation = apply scaling factor.
    pos_freqs = base ** (
        torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim
    )
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)

    low, high = find_correction_range(
        beta_fast, beta_slow, dim, base, original_max_position_embeddings
    )

    # Get n-dimensional rotational scaling corrected for extrapolation
    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(
        device=device, dtype=torch.float
    )
    inv_freq = (
        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
        + inv_freq_extrapolation * inv_freq_extrapolation_factor
    )
    return inv_freq, attention_factor
392
+
393
+
394
def _compute_longrope_parameters(
    config: PretrainedConfig,
    device: "torch.device",
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> tuple["torch.Tensor", float]:
    """
    Compute the inverse frequencies with LongRoPE scaling. Please refer to the
    [original implementation](https://github.com/microsoft/LongRoPE)

    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device on which the inverse frequencies are created.
        seq_len (`int`, *optional*):
            The current sequence length. The long factors are used when it exceeds the original maximum
            position embeddings, the short factors otherwise.
        rope_kwargs (`Dict`, *optional*):
            Must be empty; a `ValueError` is raised otherwise.
    Returns:
        Tuple of (`torch.Tensor`, `float`): the inverse frequencies and the post-processing scaling factor
        applied to the computed cos/sin.
    """
    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
    # No need to keep BC with longrope, unreleased when this new pattern was created.
    if len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
            f"{rope_kwargs}"
        )

    base = config.rope_theta
    partial_rotary_factor = (
        config.partial_rotary_factor
        if hasattr(config, "partial_rotary_factor")
        else 1.0
    )
    # Same resolution as `_compute_default_rope_parameters`: an explicit
    # `head_dim=None` falls back to hidden_size / num_heads, and the fallback
    # is only evaluated when it is actually needed.
    head_dim = (
        getattr(config, "head_dim", None)
        or config.hidden_size // config.num_attention_heads
    )
    dim = int(head_dim * partial_rotary_factor)
    long_factor = config.rope_scaling["long_factor"]
    short_factor = config.rope_scaling["short_factor"]
    factor = config.rope_scaling.get("factor")
    attention_factor = config.rope_scaling.get("attention_factor")

    # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
    # values to compute the default attention scaling factor, instead of using `factor`.
    if hasattr(config, "original_max_position_embeddings"):
        original_max_position_embeddings = config.original_max_position_embeddings
        factor = (
            config.max_position_embeddings / config.original_max_position_embeddings
        )
    else:
        original_max_position_embeddings = config.max_position_embeddings

    # Sets the attention factor as suggested in the paper
    if attention_factor is None:
        if factor <= 1.0:
            attention_factor = 1.0
        else:
            attention_factor = math.sqrt(
                1 + math.log(factor) / math.log(original_max_position_embeddings)
            )

    # Compute the inverse frequencies -- scaled based on the target sequence length
    if seq_len and seq_len > original_max_position_embeddings:
        ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
    else:
        ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
    inv_freq_shape = (
        torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
    )
    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)

    return inv_freq, attention_factor
470
+
471
+
472
def _compute_llama3_parameters(
    config: PretrainedConfig,
    device: "torch.device",
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> tuple["torch.Tensor", float]:
    """
    Compute the inverse frequencies for llama 3.1.

    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration. `config.rope_scaling` must provide `factor`, `low_freq_factor`,
            `high_freq_factor` and `original_max_position_embeddings`.
        device (`torch.device`):
            The device on which the inverse frequencies are created.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
        rope_kwargs (`Dict`, *optional*):
            Forwarded unchanged to `_compute_default_rope_parameters`.
    Returns:
        Tuple of (`torch.Tensor`, `float`): the inverse frequencies and the post-processing scaling factor
        applied to the computed cos/sin.
    """
    # Start from the default RoPE parameters
    inv_freq, attention_factor = _compute_default_rope_parameters(
        config, device, seq_len, **rope_kwargs
    )

    scaling = config.rope_scaling
    factor = scaling["factor"]  # `8` in the original implementation
    low_freq_factor = scaling["low_freq_factor"]  # `1` in the original implementation
    high_freq_factor = scaling["high_freq_factor"]  # `4` in the original implementation
    old_context_len = scaling[
        "original_max_position_embeddings"
    ]  # `8192` in the original implementation

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * math.pi / inv_freq

    is_low_freq = wavelen > low_freq_wavelen
    is_high_freq = wavelen < high_freq_wavelen

    # High frequencies (short wavelengths) are left untouched, low
    # frequencies (long wavelengths) are divided by `factor`.
    inv_freq_llama = torch.where(is_low_freq, inv_freq / factor, inv_freq)

    # Everything in between is interpolated with a smooth factor.
    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    smoothed_inv_freq = (
        1 - smooth_factor
    ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
    is_medium_freq = ~is_high_freq & ~is_low_freq
    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

    return inv_freq_llama, attention_factor
530
+
531
+
532
# Registry that maps the "rope_type" string of a rope config to the function computing the RoPE parameters
# (inverse frequencies and attention factor) from the model config. Append new {'rope_type': callable} pairs
# to this dictionary to enable custom RoPE parameterizations, as long as the callable has the same signature
# as the functions registered here.
ROPE_INIT_FUNCTIONS = {
    "default": _compute_default_rope_parameters,
    "linear": _compute_linear_scaling_rope_parameters,
    "dynamic": _compute_dynamic_ntk_parameters,
    "yarn": _compute_yarn_parameters,
    "longrope": _compute_longrope_parameters,
    "llama3": _compute_llama3_parameters,
}
543
+
544
+
545
def _check_received_keys(
    rope_type: str,
    received_keys: set,
    required_keys: set,
    optional_keys: Optional[set] = None,
    ignore_keys: Optional[set] = None,
):
    """Compare the received keys in `config.rope_scaling` against the expected and optional keys.

    The sets passed in are never modified; all bookkeeping happens on copies.

    Args:
        rope_type: RoPE type name, used only in the messages.
        received_keys: Keys found in `rope_scaling`.
        required_keys: Keys that must be present.
        optional_keys: Keys that may be present without triggering a warning.
        ignore_keys: Model-specific keys that are dropped before any check.

    Raises:
        KeyError: If one or more required keys are missing.
    """
    # Work on copies so the caller's sets stay untouched.
    received = set(received_keys)
    required = set(required_keys)

    # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
    if "type" in received:
        received.discard("type")
        required.add("rope_type")

    # Some models need to store model-specific keys, and we don't want to throw warning at them
    if ignore_keys is not None:
        received -= ignore_keys

    missing_keys = required - received
    if missing_keys:
        raise KeyError(
            f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}"
        )

    unused_keys = received - required
    if optional_keys is not None:
        unused_keys -= optional_keys
    if unused_keys:
        logger.warning(
            f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}"
        )
576
+
577
+
578
def _validate_default_rope_parameters(
    config: PretrainedConfig, ignore_keys: Optional[set] = None
):
    """Validate `config.rope_scaling` for the default RoPE type: only "rope_type" is required."""
    scaling = config.rope_scaling
    # BC: "rope_type" was originally "type"
    legacy_type = scaling.get("type", None)
    rope_type = scaling.get("rope_type", legacy_type)
    _check_received_keys(
        rope_type,
        set(scaling.keys()),
        {"rope_type"},
        ignore_keys=ignore_keys,
    )
590
+
591
+
592
def _validate_linear_scaling_rope_parameters(
    config: PretrainedConfig, ignore_keys: Optional[set] = None
):
    """Validate `config.rope_scaling` for linear scaling.

    "rope_type" and "factor" are required; a warning is logged when `factor`
    is not a float greater than or equal to 1.
    """
    scaling = config.rope_scaling
    # BC: "rope_type" was originally "type"
    legacy_type = scaling.get("type", None)
    rope_type = scaling.get("rope_type", legacy_type)
    _check_received_keys(
        rope_type,
        set(scaling.keys()),
        {"rope_type", "factor"},
        ignore_keys=ignore_keys,
    )

    factor = scaling["factor"]
    factor_ok = (
        factor is not None and isinstance(factor, float) and not factor < 1.0
    )
    if not factor_ok:
        logger.warning(
            f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
        )
610
+
611
+
612
def _validate_dynamic_scaling_rope_parameters(
    config: PretrainedConfig, ignore_keys: Optional[set] = None
):
    """Validate `config.rope_scaling` for dynamic NTK scaling.

    "rope_type" and "factor" are required, "original_max_position_embeddings"
    is optional; a warning is logged when `factor` is not a float greater
    than or equal to 1.
    """
    scaling = config.rope_scaling
    # BC: "rope_type" was originally "type"
    legacy_type = scaling.get("type", None)
    rope_type = scaling.get("rope_type", legacy_type)
    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
    _check_received_keys(
        rope_type,
        set(scaling.keys()),
        {"rope_type", "factor"},
        {"original_max_position_embeddings"},
        ignore_keys=ignore_keys,
    )

    factor = scaling["factor"]
    factor_ok = (
        factor is not None and isinstance(factor, float) and not factor < 1.0
    )
    if not factor_ok:
        logger.warning(
            f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
        )
632
+
633
+
634
def _validate_yarn_parameters(
    config: PretrainedConfig, ignore_keys: Optional[set] = None
):
    """Check `config.rope_scaling` for the "yarn" RoPE type.

    Requires `rope_type` and `factor`, and accepts the optional yarn keys listed
    below. Every value problem is reported with `logger.warning`; nothing is
    raised by the value checks themselves.
    """
    scaling_cfg = config.rope_scaling
    # BC: "rope_type" was originally "type"
    legacy_type = scaling_cfg.get("type", None)
    rope_type = scaling_cfg.get("rope_type", legacy_type)

    required_keys = {"rope_type", "factor"}
    optional_keys = {
        "attention_factor",
        "beta_fast",
        "beta_slow",
        "original_max_position_embeddings",
        "mscale",
        "mscale_all_dim",
    }
    _check_received_keys(
        rope_type,
        set(scaling_cfg.keys()),
        required_keys,
        optional_keys,
        ignore_keys=ignore_keys,
    )

    factor = scaling_cfg["factor"]
    # `None` is not a float, so the isinstance test also covers a missing value.
    if not isinstance(factor, float) or factor < 1.0:
        logger.warning(
            f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
        )

    attention_factor = scaling_cfg.get("attention_factor")
    if attention_factor is not None:
        if not isinstance(attention_factor, float) or attention_factor < 0:
            logger.warning(
                f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
            )

    beta_fast = scaling_cfg.get("beta_fast")
    if beta_fast is not None:
        if not isinstance(beta_fast, float):
            logger.warning(
                f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}"
            )

    beta_slow = scaling_cfg.get("beta_slow")
    if beta_slow is not None:
        if not isinstance(beta_slow, float):
            logger.warning(
                f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}"
            )

    # Falsy values (None, 0) fall back to the defaults used in the message below.
    effective_fast = beta_fast or 32
    effective_slow = beta_slow or 1
    if effective_fast < effective_slow:
        logger.warning(
            f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
            f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
        )
684
+
685
+
686
def _warn_on_invalid_longrope_factors(field_name, values, expected_len):
    """Warn when a longrope factor field is not a list of numbers or has the wrong length."""
    is_number_list = isinstance(values, list) and all(
        isinstance(x, (int, float)) for x in values
    )
    if not is_number_list:
        logger.warning(
            f"`rope_scaling`'s {field_name} field must be a list of numbers, got {values}"
        )
    # Only sized values can be length-checked; unsized ones were already reported above.
    if hasattr(values, "__len__") and len(values) != expected_len:
        logger.warning(
            f"`rope_scaling`'s {field_name} field must have length {expected_len}, got {len(values)}"
        )


def _validate_longrope_parameters(
    config: PretrainedConfig, ignore_keys: Optional[set] = None
):
    """Check `config.rope_scaling` for the "longrope" RoPE type.

    Requires `rope_type`, `short_factor` and `long_factor`; `attention_factor`,
    `factor` and `original_max_position_embeddings` are optional keys. Both
    factor lists are expected to be lists of numbers of length `dim // 2`, where
    `dim = int(head_dim * partial_rotary_factor)`.

    All value problems are reported with `logger.warning`. The list check used to
    read `not isinstance(v, list) and all(...)`, which never flagged a list with
    non-numeric entries and raised `TypeError` for non-iterable values; it now
    warns in both cases.
    """
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get(
        "rope_type", rope_scaling.get("type", None)
    )  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "short_factor", "long_factor"}
    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
    optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(
        rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys
    )

    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
    head_dim = getattr(
        config, "head_dim", config.hidden_size // config.num_attention_heads
    )
    dim = int(head_dim * partial_rotary_factor)

    _warn_on_invalid_longrope_factors(
        "short_factor", rope_scaling.get("short_factor"), dim // 2
    )
    _warn_on_invalid_longrope_factors(
        "long_factor", rope_scaling.get("long_factor"), dim // 2
    )

    # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
    # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
    # unique to longrope (= undesirable)
    if hasattr(config, "original_max_position_embeddings"):
        logger.warning_once(
            "This model has set a `original_max_position_embeddings` field, to be used together with "
            "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
            "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
            "as it is compatible with most model architectures."
        )
    else:
        factor = rope_scaling.get("factor")
        if factor is None:
            logger.warning("Missing required keys in `rope_scaling`: 'factor'")
        elif not isinstance(factor, float) or factor < 1.0:
            logger.warning(
                f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
            )

        attention_factor = rope_scaling.get("attention_factor")
        if attention_factor is not None:
            if not isinstance(attention_factor, float) or attention_factor < 0.0:
                logger.warning(
                    f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
                )
760
+
761
+
762
def _validate_llama3_parameters(
    config: PretrainedConfig, ignore_keys: Optional[set] = None
):
    """Check `config.rope_scaling` for the "llama3" RoPE type.

    Requires `rope_type`, `factor`, `original_max_position_embeddings`,
    `low_freq_factor` and `high_freq_factor`. Type and range problems are
    reported with `logger.warning`.

    The ordering checks (`high_freq_factor` vs `low_freq_factor`, and
    `original_max_position_embeddings` vs `config.max_position_embeddings`) are
    only evaluated when both operands are numbers. Previously a `None` or
    non-numeric value was warned about and then compared anyway, which raised
    `TypeError` instead of finishing validation.
    """
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get(
        "rope_type", rope_scaling.get("type", None)
    )  # BC: "rope_type" was originally "type"
    required_keys = {
        "rope_type",
        "factor",
        "original_max_position_embeddings",
        "low_freq_factor",
        "high_freq_factor",
    }
    received_keys = set(rope_scaling.keys())
    _check_received_keys(
        rope_type, received_keys, required_keys, ignore_keys=ignore_keys
    )

    numeric_types = (int, float)

    factor = rope_scaling["factor"]
    if factor is None or not isinstance(factor, float) or factor < 1.0:
        logger.warning(
            f"`rope_scaling`'s factor field must be a float >= 1, got {factor}"
        )

    low_freq_factor = rope_scaling["low_freq_factor"]
    high_freq_factor = rope_scaling["high_freq_factor"]
    if low_freq_factor is None or not isinstance(low_freq_factor, float):
        logger.warning(
            f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}"
        )
    if high_freq_factor is None or not isinstance(high_freq_factor, float):
        logger.warning(
            f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}"
        )
    # Integers are still compared (as before); only non-numeric values are skipped.
    if (
        isinstance(low_freq_factor, numeric_types)
        and isinstance(high_freq_factor, numeric_types)
        and high_freq_factor <= low_freq_factor
    ):
        logger.warning(
            "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
            f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
        )

    original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
    if original_max_position_embeddings is None or not isinstance(
        original_max_position_embeddings, int
    ):
        logger.warning(
            "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
            f"{original_max_position_embeddings}"
        )
    if (
        isinstance(original_max_position_embeddings, numeric_types)
        and isinstance(config.max_position_embeddings, numeric_types)
        and original_max_position_embeddings >= config.max_position_embeddings
    ):
        logger.warning(
            "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
            f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
        )
816
+
817
+
818
# Registry mapping each supported `rope_type` string to the function that validates its
# `rope_scaling` entries. Like `ROPE_INIT_FUNCTIONS`, this mapping can be dynamically
# updated to register custom RoPE types.
ROPE_VALIDATION_FUNCTIONS = {
    "default": _validate_default_rope_parameters,
    "linear": _validate_linear_scaling_rope_parameters,
    "dynamic": _validate_dynamic_scaling_rope_parameters,
    "yarn": _validate_yarn_parameters,
    "longrope": _validate_longrope_parameters,
    "llama3": _validate_llama3_parameters,
}
827
+
828
+
829
def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    """
    Validate the RoPE config arguments, given a `PretrainedConfig` object.

    Does nothing when the config has no `rope_scaling` (or it is `None`). Otherwise the
    validator registered in `ROPE_VALIDATION_FUNCTIONS` for the configured RoPE type is
    called; an unregistered type only produces a warning.
    """
    # `rope_scaling` is not a default parameter in `PretrainedConfig`
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is None:
        return

    # BC: "rope_type" was originally "type"
    legacy_type = rope_scaling.get("type", "default")
    rope_type = rope_scaling.get("rope_type", legacy_type)

    validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
    if validation_fn is None:
        logger.warning(
            f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
        )
        return
    validation_fn(config, ignore_keys=ignore_keys)
848
+
849
+
850
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is split at `x.shape[-1] // 2`; the result is the negated
    second part followed by the first part, concatenated along the last dimension.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)
855
+
856
def apply_rotary_pos_emb(x, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to a single input tensor.

    Args:
        x (`torch.Tensor`): The input tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `x`.
            For example, if `cos`/`sin` have the shape [batch_size, seq_len, head_dim] and `x` has the shape
            [batch_size, heads, seq_len, head_dim], use unsqueeze_dim=1; if `x` has the shape
            [batch_size, seq_len, heads, head_dim], use unsqueeze_dim=2.
    Returns:
        `torch.Tensor`: `x` rotated using the Rotary Position Embedding, i.e.
        `x * cos + rotate_half(x) * sin`.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    straight = x * cos_b
    rotated = rotate_half(x) * sin_b
    return straight + rotated
src/mimo_audio_tokenizer/quantization.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Xiaomi Corporation.
2
+ import typing as tp
3
+
4
+ from einops import rearrange, repeat
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+
9
+ import torch.distributed as dist
10
+
11
+
12
def rank():
    """Return this process's distributed rank, or 0 when no process group is initialized."""
    return dist.get_rank() if dist.is_initialized() else 0
17
+
18
+
19
def world_size():
    """Return the distributed world size, or 1 when no process group is initialized."""
    return dist.get_world_size() if dist.is_initialized() else 1
24
+
25
+
26
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return `val` unless it is `None`, in which case return the fallback `d`."""
    if val is None:
        return d
    return val
28
+
29
+
30
def ema_inplace(moving_avg, new, decay: float):
    """Update `moving_avg` in place as `decay * moving_avg + (1 - decay) * new`.

    When a process group is initialized, `new` is first sum-reduced across all
    workers (in place).
    """
    if dist.is_initialized():
        dist.all_reduce(new, op=dist.ReduceOp.SUM)
    target = moving_avg.data
    target.mul_(decay)
    target.add_(new, alpha=(1 - decay))
34
+
35
+
36
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Return `(x + epsilon) / (x.sum() + n_categories * epsilon)`."""
    numerator = x + epsilon
    denominator = x.sum() + n_categories * epsilon
    return numerator / denominator
38
+
39
+
40
def uniform_init(*shape: int):
    """Create a tensor of the given shape filled by `nn.init.kaiming_uniform_`."""
    # `kaiming_uniform_` fills the tensor in place and returns that same tensor.
    return nn.init.kaiming_uniform_(torch.empty(shape))
44
+
45
+
46
def sample_vectors(samples, num: int):
    """Pick `num` rows from `samples` at random.

    Rows are drawn without repetition (a truncated random permutation) when
    `samples` has at least `num` rows, and with repetition (`torch.randint`)
    otherwise. When a process group is initialized, the selection made on rank 0
    is broadcast to all workers.
    """
    total = samples.shape[0]
    device = samples.device

    if total < num:
        picks = torch.randint(0, total, (num,), device=device)
    else:
        picks = torch.randperm(total, device=device)[:num]

    chosen = samples[picks]

    if dist.is_initialized():
        dist.broadcast(chosen, src=0)

    return chosen
61
+
62
+
63
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run k-means on the rows of `samples`.

    Centroids start as `num_clusters` rows chosen by `sample_vectors`. Each
    iteration assigns every sample to its nearest centroid (squared Euclidean
    distance) and replaces each centroid with the mean of its samples; a
    centroid with no samples keeps its previous value. Counts and sums are
    sum-reduced across workers when a process group is initialized.

    Returns:
        The final centroids and the per-cluster sample counts of the last
        iteration.
    """
    feat_dim = samples.shape[-1]
    sample_dtype = samples.dtype

    centroids = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # Negated squared distance, expanded as -(|s|^2 - 2 s.c + |c|^2).
        sample_sq = samples.pow(2).sum(1, keepdim=True)
        cross = 2 * samples @ centroids.t()
        centroid_sq = centroids.t().pow(2).sum(0, keepdim=True)
        neg_sq_dist = -(sample_sq - cross + centroid_sq)

        # Largest negated distance == nearest centroid.
        buckets = neg_sq_dist.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)

        sums = buckets.new_zeros(num_clusters, feat_dim, dtype=sample_dtype)
        sums = sums.scatter_add_(
            0, repeat(buckets, "n -> n d", d=feat_dim), samples
        )

        if dist.is_initialized():
            dist.all_reduce(bins, op=dist.ReduceOp.SUM)
            dist.all_reduce(sums, op=dist.ReduceOp.SUM)

        empty = bins == 0
        # Avoid dividing by zero for empty clusters; their result is discarded below.
        safe_counts = bins.masked_fill(empty, 1)
        averaged = sums / safe_counts[..., None]

        centroids = torch.where(empty[..., None], centroids, averaged)

    return centroids, bins
95
+
96
+
97
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first batch passed to
            `forward` and use the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        # With k-means init the table starts as zeros and is filled on the first forward pass.
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
            uniform_init if not kmeans_init else torch.zeros
        )
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # `inited` is 1.0 once the codebook holds usable values.
        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        """Initialize the codebook from `data` with k-means, unless already initialized."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))

    def replace_(self, samples, mask):
        """Overwrite the codebook rows selected by `mask` with rows sampled from `samples`."""
        # Use a plain Python int so `sample_vectors` receives the `int` its signature declares
        # instead of a 0-dim tensor.
        replace_num = int(mask.sum().item())
        modified_codebook = self.embed.clone()
        modified_codebook[mask] = sample_vectors(samples, replace_num)
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Replace codes whose EMA cluster size is below `threshold_ema_dead_code`."""
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)

    def preprocess(self, x):
        """Flatten all leading dimensions of `x` into one, keeping the last dimension."""
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        """Return, for each row of the 2-D input `x`, the index of the nearest codebook entry."""
        embed = self.embed.t()
        # Negated squared Euclidean distance; named so it does not shadow the module-level
        # `dist` alias of `torch.distributed`.
        neg_sq_dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = neg_sq_dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        """Reshape flat indices back to `shape` without its last dimension."""
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        """Look up the codebook vectors for the given indices."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        """Map `x` to codebook indices shaped like `x` minus its last dimension."""
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        """Map codebook indices back to their vectors."""
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        """Quantize `x` and, in training mode, update the codebook statistics.

        Returns:
            The quantized vectors and the codebook indices.
        """
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)

        self.init_embed_(x)

        flat_ind = self.quantize(x)
        embed_ind = self.postprocess_emb(flat_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            # The one-hot assignment matrix is only needed for the EMA updates, so it is
            # built here rather than on every (inference) call.
            embed_onehot = F.one_hot(flat_ind, self.codebook_size).type(dtype)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t().contiguous(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind
233
+
234
+
235
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        inner_dim: int = default(codebook_dim, dim)

        # Linear projections are only needed when the codebook lives in a different dimension.
        if inner_dim != dim:
            self.project_in = nn.Linear(dim, inner_dim)
            self.project_out = nn.Linear(inner_dim, dim)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=inner_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        """The embedding table of the underlying codebook."""
        return self._codebook.embed

    def encode(self, x):
        """Project `x` into the codebook space and return the codebook indices."""
        projected = self.project_in(x)
        return self._codebook.encode(projected)

    def decode(self, embed_ind):
        """Look up the vectors for `embed_ind` and project them back to the input space."""
        looked_up = self._codebook.decode(embed_ind)
        return self.project_out(looked_up)

    def forward(self, x):
        """Quantize `x`.

        Returns:
            The quantized tensor (projected back out), the codebook indices, and
            a one-element loss tensor that holds the weighted commitment loss in
            training mode and zero otherwise.
        """
        device = x.device
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        training = self.training
        if training:
            # Straight-through: forward value is the quantized vector, gradient flows to `x`.
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=training)

        if training and self.commitment_weight > 0:
            commit_loss = F.mse_loss(quantize.detach(), x)
            loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        return quantize, embed_ind, loss
324
+
325
+
326
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

    Args:
        num_quantizers (int): Number of `VectorQuantization` layers.
        codebook_size (int or sequence of int): Codebook size of each layer. A single int is
            used for every layer; a sequence shorter than `num_quantizers` is padded with its
            last value. The caller's sequence is copied, never modified.
        **kwargs: Forwarded to every `VectorQuantization` layer.
    """

    def __init__(self, *, num_quantizers, codebook_size, **kwargs):
        super().__init__()
        if isinstance(codebook_size, int):
            sizes = [codebook_size] * num_quantizers
        else:
            # Copy first: padding with `+=` on the argument would mutate the caller's list
            # in place (and fail outright for tuples).
            sizes = list(codebook_size)
            if len(sizes) < num_quantizers:
                sizes += [sizes[-1]] * (num_quantizers - len(sizes))
        self.layers = nn.ModuleList(
            [
                VectorQuantization(codebook_size=sizes[i], **kwargs)
                for i in range(num_quantizers)
            ]
        )

    def forward(
        self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
    ):
        """Quantize `x` with the first `n_q` layers, each layer encoding the previous residual.

        Args:
            x: Input tensor.
            n_q: Number of layers to use; `None` (or 0) means all layers.
            layers: Optional layer indices after which the running quantized sum is collected.

        Returns:
            The summed quantized output, the stacked indices, the stacked losses, and the list
            of running sums collected for `layers`.
        """
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []
        out_quantized = []

        n_q = n_q or len(self.layers)

        for i, layer in enumerate(self.layers[:n_q]):
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)
            if layers and i in layers:
                out_quantized.append(quantized_out)

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses, out_quantized

    def encode(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
    ) -> torch.Tensor:
        """Return the stacked indices produced by layers `st` (default 0) up to `n_q` (default all)."""
        residual = x
        all_indices = []
        n_q = len(self.layers) if n_q is None else n_q
        st = 0 if st is None else st
        for layer in self.layers[st:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
        """Sum the decoded vectors of each index slice, using layers starting at index `st`."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[st + i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
391
+
392
+
393
class ResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.
    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        bins: int | list = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        # Keep the constructor arguments on the instance.
        self.n_q = n_q
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
        )

    def forward(
        self,
        x: torch.Tensor,
        n_q: tp.Optional[int] = None,
        layers: tp.Optional[list] = None,
    ):
        """Residual vector quantization on the given input tensor.
        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. Default: All quantizers.
            layers (list): Layers whose running quantized output should be returned. Default: None.
        Returns:
            The quantized representation, the codes of each quantizer used, the mean
            commitment loss, and the list of quantized outputs collected for `layers`.
        """
        active_q = n_q or self.n_q
        quantized, codes, commit_loss, quantized_list = self.vq(
            x, n_q=active_q, layers=layers
        )
        mean_loss = torch.mean(commit_loss)
        return quantized, codes, mean_loss, quantized_list

    def encode(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
    ) -> torch.Tensor:
        """Encode the given input tensor into codes.
        Returns the indices produced by each quantizer that was used.
        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. Default: All quantizers.
            st (int): Index of the layer from which encoding starts. Default: 0.
        """
        active_q = n_q or self.n_q
        start = st if st else 0
        return self.vq.encode(x, n_q=active_q, st=start)

    def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
        """Decode the given codes to the quantized representation.
        Args:
            codes (torch.Tensor): Input indices for each quantizer.
            st (int): Index of the layer from which decoding starts. Default: 0.
        """
        return self.vq.decode(codes, st=st)