shadowlilac committed (verified)
Commit 9602b54 · 1 Parent(s): ea4b70d

Upload 2 files

Files changed (2):
  1. configuration_mimo_v2.py +247 -0
  2. modeling_mimo_v2.py +1878 -0
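Both files are custom "remote code" modules for the MiMo-V2 architecture: configuration_mimo_v2.py defines MiMoV2Config, and modeling_mimo_v2.py implements the PyTorch model. A minimal sketch of how a repo carrying these files is typically consumed, assuming the repo's config.json wires the classes up via auto_map entries (including a causal-LM entry); the repo id below is a placeholder:

from transformers import AutoConfig, AutoModelForCausalLM

# Placeholder repo id; trust_remote_code=True lets transformers import the two
# uploaded modules from the Hub instead of using a built-in architecture.
repo_id = "your-namespace/MiMo-V2"
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True, torch_dtype="auto")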
configuration_mimo_v2.py ADDED
@@ -0,0 +1,247 @@
+ # coding=utf-8
+ #
+ # Copyright 2026 Xiaomi Corporation.
+ # Copyright 2026 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from copy import deepcopy
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_rope_utils import rope_config_validation
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ _MIMOV2_ATTENTION_PROJECTION_LAYOUTS = {"split", "fused_qkv"}
+
+ _MIMOV2_SPLIT_TP_PLAN = {
+     "layers.*.self_attn.q_proj": "colwise",
+     "layers.*.self_attn.k_proj": "colwise",
+     "layers.*.self_attn.v_proj": "colwise",
+     "layers.*.self_attn.o_proj": "rowwise",
+     "layers.*.mlp.gate_proj": "colwise",
+     "layers.*.mlp.up_proj": "colwise",
+     "layers.*.mlp.down_proj": "rowwise",
+ }
+
+ _MIMOV2_FUSED_QKV_TP_PLAN = {
+     "layers.*.self_attn.qkv_proj": "colwise",
+     "layers.*.self_attn.o_proj": "rowwise",
+     "layers.*.mlp.gate_proj": "colwise",
+     "layers.*.mlp.up_proj": "colwise",
+     "layers.*.mlp.down_proj": "rowwise",
+ }
+
+ _MIMOV2_PP_PLAN = {
+     "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+     "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+     "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+
+ def _to_plain_dict(value):
+     if value is None:
+         return {}
+     if isinstance(value, dict):
+         return deepcopy(value)
+     if hasattr(value, "to_dict"):
+         return deepcopy(value.to_dict())
+     if hasattr(value, "__dict__"):
+         return deepcopy(vars(value))
+     raise TypeError(f"Unsupported config value type: {type(value)!r}")
+
+
+ class MiMoV2Config(PretrainedConfig):
+
+     model_type = "mimo_v2"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     base_model_tp_plan = _MIMOV2_SPLIT_TP_PLAN
+     base_model_pp_plan = _MIMOV2_PP_PLAN
+
+     attribute_map = {
+         "num_local_experts": "n_routed_experts",
+     }
+
+     def __init__(
+         self,
+         vocab_size=151936,
+         hidden_size=4096,
+         intermediate_size=22016,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         num_key_value_heads=32,
+         hidden_act="silu",
+         max_position_embeddings=32768,
+         initializer_range=0.02,
+         layernorm_epsilon=1e-6,
+         use_cache=True,
+         tie_word_embeddings=False,
+         rope_theta=10000.0,
+         rope_scaling=None,
+         attention_dropout=0.0,
+         attention_bias=False,
+         attention_value_scale=None,
+         head_dim=None,
+         v_head_dim=None,
+         swa_num_attention_heads=None,
+         swa_num_key_value_heads=None,
+         swa_head_dim=None,
+         swa_v_head_dim=None,
+         swa_rope_theta=None,
+         sliding_window=None,
+         sliding_window_size=None,
+         add_full_attention_sink_bias=False,
+         add_swa_attention_sink_bias=False,
+         hybrid_block_size=None,
+         hybrid_layer_pattern=None,
+         partial_rotary_factor=1.0,
+         n_routed_experts=None,
+         moe_intermediate_size=None,
+         num_experts_per_tok=None,
+         routed_scaling_factor=None,
+         scoring_func="sigmoid",
+         topk_method="noaux_tc",
+         n_group=None,
+         topk_group=None,
+         norm_topk_prob=True,
+         moe_layer_freq=None,
+         attention_projection_layout="split",
+         vision_config=None,
+         audio_config=None,
+         processor_config=None,
+         image_token_id=None,
+         video_token_id=None,
+         vision_start_token_id=None,
+         vision_end_token_id=None,
+         vision_model_type=None,
+         **kwargs,
+     ):
+         rope_parameters = kwargs.pop("rope_parameters", None)
+         if rope_scaling is None and rope_parameters is not None:
+             rope_scaling = rope_parameters
+
+         if attention_projection_layout is None:
+             attention_projection_layout = "split"
+         if attention_projection_layout not in _MIMOV2_ATTENTION_PROJECTION_LAYOUTS:
+             raise ValueError(f"Unsupported MiMoV2 attention projection layout: {attention_projection_layout}")
+
+         self.attention_projection_layout = attention_projection_layout
+         self.base_model_tp_plan = (
+             _MIMOV2_FUSED_QKV_TP_PLAN.copy()
+             if attention_projection_layout == "fused_qkv"
+             else _MIMOV2_SPLIT_TP_PLAN.copy()
+         )
+         self.base_model_pp_plan = _MIMOV2_PP_PLAN.copy()
+
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+         if num_attention_heads % num_key_value_heads != 0:
+             raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
+
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.layernorm_epsilon = layernorm_epsilon
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.attention_dropout = attention_dropout
+         self.attention_bias = attention_bias
+         self.attention_value_scale = attention_value_scale
+
+         self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads
+         self.v_head_dim = v_head_dim if v_head_dim is not None else self.head_dim
+         self.swa_num_attention_heads = (
+             swa_num_attention_heads if swa_num_attention_heads is not None else num_attention_heads
+         )
+         self.swa_num_key_value_heads = (
+             swa_num_key_value_heads if swa_num_key_value_heads is not None else num_key_value_heads
+         )
+         if self.swa_num_attention_heads % self.swa_num_key_value_heads != 0:
+             raise ValueError("swa_num_attention_heads must be divisible by swa_num_key_value_heads")
+         self.swa_head_dim = swa_head_dim if swa_head_dim is not None else self.head_dim
+         self.swa_v_head_dim = swa_v_head_dim if swa_v_head_dim is not None else self.swa_head_dim
+         self.swa_rope_theta = swa_rope_theta if swa_rope_theta is not None else rope_theta
+
+         if sliding_window is None:
+             sliding_window = sliding_window_size
+         self.sliding_window = sliding_window
+         self.sliding_window_size = sliding_window_size if sliding_window_size is not None else sliding_window
+         self.add_full_attention_sink_bias = add_full_attention_sink_bias
+         self.add_swa_attention_sink_bias = add_swa_attention_sink_bias
+
+         if hybrid_block_size is not None and hybrid_layer_pattern is None:
+             hybrid_layer_pattern = [0 if ((i + 1) % hybrid_block_size == 0) else 1 for i in range(num_hidden_layers)]
+         elif hybrid_layer_pattern is None:
+             hybrid_layer_pattern = [0] * num_hidden_layers
+         if len(hybrid_layer_pattern) != num_hidden_layers:
+             raise ValueError("hybrid_layer_pattern length must match num_hidden_layers")
+         self.hybrid_block_size = hybrid_block_size
+         self.hybrid_layer_pattern = hybrid_layer_pattern
+
+         self.partial_rotary_factor = partial_rotary_factor
+
+         self.n_routed_experts = n_routed_experts
+         self.moe_intermediate_size = moe_intermediate_size if moe_intermediate_size is not None else intermediate_size
+         self.num_experts_per_tok = num_experts_per_tok
+         self.routed_scaling_factor = routed_scaling_factor
+         self.scoring_func = scoring_func
+         self.topk_method = topk_method
+         self.n_group = n_group
+         self.topk_group = topk_group
+         self.norm_topk_prob = norm_topk_prob
+         if isinstance(moe_layer_freq, int):
+             moe_layer_freq = [moe_layer_freq > 0 and i % moe_layer_freq == 0 for i in range(num_hidden_layers)]
+         elif moe_layer_freq is None:
+             moe_layer_freq = [False] * num_hidden_layers
+         if len(moe_layer_freq) != num_hidden_layers:
+             raise ValueError("moe_layer_freq length must match num_hidden_layers")
+         self.moe_layer_freq = moe_layer_freq
+
+         self.vision_config = _to_plain_dict(vision_config)
+         self.audio_config = _to_plain_dict(audio_config)
+         self.processor_config = _to_plain_dict(processor_config)
+         self.image_token_id = image_token_id
+         self.video_token_id = video_token_id
+         self.vision_start_token_id = vision_start_token_id
+         self.vision_end_token_id = vision_end_token_id
+         self.vision_model_type = vision_model_type
+         self.audio_token_id = self.processor_config.get("audio_token_id", None) if self.processor_config else None
+         self.audio_start_token_id = (
+             self.processor_config.get("audio_start_token_id", None) if self.processor_config else None
+         )
+         self.audio_end_token_id = (
+             self.processor_config.get("audio_end_token_id", None) if self.processor_config else None
+         )
+
+         if self.rope_scaling is not None and "type" in self.rope_scaling:
+             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+         rope_config_validation(self)
+
+         super().__init__(
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+ __all__ = ["MiMoV2Config"]
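A minimal sketch of how this configuration behaves, assuming the module above is importable on the Python path (the snippet is not part of the uploaded file): hybrid_block_size expands into a per-layer hybrid_layer_pattern (0 = full attention, 1 = sliding-window attention), and attention_projection_layout selects between the split-q/k/v and fused-qkv tensor-parallel plans.

from configuration_mimo_v2 import MiMoV2Config

# With hybrid_block_size=4, every 4th layer uses full attention (0); the rest use SWA (1).
cfg = MiMoV2Config(
    num_hidden_layers=8,
    hybrid_block_size=4,
    sliding_window=4096,
    attention_projection_layout="fused_qkv",
)
print(cfg.hybrid_layer_pattern)   # [1, 1, 1, 0, 1, 1, 1, 0]
print(cfg.base_model_tp_plan)     # fused layout: layers.*.self_attn.qkv_proj is sharded column-wise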
modeling_mimo_v2.py ADDED
@@ -0,0 +1,1878 @@
1
+ # coding=utf-8
2
+ #
3
+ # Copyright 2026 Xiaomi Corporation.
4
+ # Copyright 2026 The HuggingFace Inc. team.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import math
19
+ from copy import copy
20
+ from types import SimpleNamespace
21
+ from typing import Callable, Optional, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache
29
+ from transformers.configuration_utils import PretrainedConfig
30
+ from transformers.generation import GenerationMixin
31
+ from transformers.integrations import use_kernel_forward_from_hub
32
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
33
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
34
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
35
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
36
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
37
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
38
+ from transformers.processing_utils import Unpack
39
+ from transformers.utils import TransformersKwargs, can_return_tuple, logging
40
+
41
+ from .configuration_mimo_v2 import MiMoV2Config
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+
47
+ def rotate_half(x):
48
+ """Rotates half the hidden dims of the input."""
49
+ x1 = x[..., : x.shape[-1] // 2]
50
+ x2 = x[..., x.shape[-1] // 2 :]
51
+ return torch.cat((-x2, x1), dim=-1)
52
+
53
+
54
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
55
+ """Applies rotary position embedding to query and key tensors."""
56
+ cos = cos.unsqueeze(unsqueeze_dim)
57
+ sin = sin.unsqueeze(unsqueeze_dim)
58
+ q_embed = (q * cos) + (rotate_half(q) * sin)
59
+ k_embed = (k * cos) + (rotate_half(k) * sin)
60
+ return q_embed, k_embed
61
+
62
+
63
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
64
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
65
+ if n_rep == 1:
66
+ return hidden_states
67
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
68
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
69
+
70
+
71
+ def eager_attention_forward(
72
+ module: nn.Module,
73
+ query: torch.Tensor,
74
+ key: torch.Tensor,
75
+ value: torch.Tensor,
76
+ attention_mask: Optional[torch.Tensor],
77
+ scaling: float,
78
+ dropout: float = 0.0,
79
+ sinks: Optional[torch.Tensor] = None,
80
+ **kwargs,
81
+ ):
82
+ key_states = repeat_kv(key, module.num_key_value_groups)
83
+ value_states = repeat_kv(value, module.num_key_value_groups)
84
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
85
+ if attention_mask is not None:
86
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
87
+ attn_weights = attn_weights + causal_mask
88
+
89
+ if sinks is not None:
90
+ sinks = module.attention_sink_bias.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
91
+ attn_weights = torch.cat([attn_weights, sinks], dim=-1)
92
+
93
+ attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True).values
94
+ probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
95
+
96
+ if sinks is not None:
97
+ probs = probs[..., :-1]
98
+
99
+ attn_weights = nn.functional.dropout(probs, p=dropout, training=module.training)
100
+ attn_output = torch.matmul(attn_weights, value_states)
101
+ attn_output = attn_output.transpose(1, 2).contiguous()
102
+ return attn_output, attn_weights
103
+
104
+
105
+ @use_kernel_forward_from_hub("RMSNorm")
106
+ class MiMoV2RMSNorm(nn.Module):
107
+ def __init__(self, hidden_size, eps=1e-6):
108
+ super().__init__()
109
+ self.weight = nn.Parameter(torch.ones(hidden_size))
110
+ self.variance_epsilon = eps
111
+
112
+ def forward(self, hidden_states):
113
+ input_dtype = hidden_states.dtype
114
+ hidden_states = hidden_states.to(torch.float32)
115
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
116
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
117
+ return self.weight * hidden_states.to(input_dtype)
118
+
119
+
120
+ class MiMoV2MLP(nn.Module):
121
+ def __init__(self, config, intermediate_size=None):
122
+ super().__init__()
123
+ self.config = config
124
+ self.hidden_size = config.hidden_size
125
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
126
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
127
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
128
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
129
+ self.act_fn = ACT2FN[config.hidden_act]
130
+
131
+ def forward(self, hidden_states):
132
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
133
+
134
+
135
+ class MiMoV2MoEGate(nn.Module):
136
+ def __init__(self, config):
137
+ super().__init__()
138
+ self.config = config
139
+ self.top_k = config.num_experts_per_tok
140
+ self.n_routed_experts = config.n_routed_experts
141
+ self.routed_scaling_factor = config.routed_scaling_factor if config.routed_scaling_factor is not None else 1.0
142
+ self.scoring_func = config.scoring_func
143
+ self.topk_method = config.topk_method
144
+ self.n_group = config.n_group
145
+ self.topk_group = config.topk_group
146
+ self.norm_topk_prob = config.norm_topk_prob
147
+ self.gating_dim = config.hidden_size
148
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
149
+ if self.topk_method == "noaux_tc":
150
+ self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
151
+
152
+ def forward(self, hidden_states):
153
+ bsz, seq_len, h = hidden_states.shape
154
+ hidden_states = hidden_states.view(-1, h)
155
+ logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
156
+ if self.scoring_func == "sigmoid":
157
+ scores = logits.sigmoid()
158
+ else:
159
+ raise NotImplementedError(f"Unsupported scoring function for MoE gating: {self.scoring_func}")
160
+
161
+ if self.topk_method == "noaux_tc":
162
+ if self.training:
163
+ raise ValueError("MiMoV2 noaux_tc routing is only implemented for inference.")
164
+ scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
165
+ group_scores = scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
166
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
167
+ group_mask = torch.zeros_like(group_scores)
168
+ group_mask.scatter_(1, group_idx, 1)
169
+ score_mask = (
170
+ group_mask.unsqueeze(-1)
171
+ .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)
172
+ .reshape(bsz * seq_len, -1)
173
+ )
174
+ tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf"))
175
+ _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
176
+ topk_weight = scores.gather(1, topk_idx)
177
+ else:
178
+ raise NotImplementedError(f"Unsupported TopK function for MoE gating: {self.topk_method}")
179
+
180
+ if self.top_k > 1 and self.norm_topk_prob:
181
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
182
+ topk_weight = topk_weight / denominator
183
+ topk_weight = topk_weight * self.routed_scaling_factor
184
+ return topk_idx, topk_weight
185
+
186
+
187
+ class MiMoV2MoE(nn.Module):
188
+ def __init__(self, config):
189
+ super().__init__()
190
+ self.config = config
191
+ self.experts = nn.ModuleList(
192
+ [MiMoV2MLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_routed_experts)]
193
+ )
194
+ self.gate = MiMoV2MoEGate(config)
195
+
196
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
197
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
198
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
199
+ expert_mask = expert_mask.permute(2, 0, 1)
200
+
201
+ for expert_idx, expert in enumerate(self.experts):
202
+ mask = expert_mask[expert_idx]
203
+ token_indices, weight_indices = torch.where(mask)
204
+ if token_indices.numel() > 0:
205
+ expert_weights = topk_weights[token_indices, weight_indices]
206
+ expert_input = hidden_states[token_indices]
207
+ expert_output = expert(expert_input)
208
+ final_hidden_states.index_add_(0, token_indices, expert_output * expert_weights.unsqueeze(-1))
209
+
210
+ return final_hidden_states.type(hidden_states.dtype)
211
+
212
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
213
+ orig_shape = hidden_states.shape
214
+ topk_indices, topk_weights = self.gate(hidden_states)
215
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
216
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
217
+ return hidden_states
218
+
219
+
220
+ class MiMoV2Attention(nn.Module):
221
+ """MiMoV2 attention.
222
+
223
+ `projection_layout` only controls how checkpoint weights are named and
224
+ stored: Flash uses separate q/k/v projections, while Pro uses fused qkv.
225
+ The attention computation after projection is shared.
226
+ """
227
+
228
+ def __init__(self, config, is_swa: bool, layer_idx: int, projection_layout: str = "split"):
229
+ super().__init__()
230
+ if projection_layout not in {"split", "fused_qkv"}:
231
+ raise ValueError(f"Unsupported MiMoV2 attention projection layout: {projection_layout}")
232
+
233
+ self.config = config
234
+ self.layer_idx = layer_idx
235
+ self.is_swa = is_swa
236
+ self.is_causal = True
237
+ self.projection_layout = projection_layout
238
+
239
+ default_head_dim = config.hidden_size // config.num_attention_heads
240
+ default_v_head_dim = getattr(config, "v_head_dim", default_head_dim)
241
+
242
+ if is_swa:
243
+ self.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", default_head_dim))
244
+ self.v_head_dim = getattr(config, "swa_v_head_dim", default_v_head_dim)
245
+ self.num_attention_heads = getattr(config, "swa_num_attention_heads", config.num_attention_heads)
246
+ self.num_key_value_heads = getattr(config, "swa_num_key_value_heads", config.num_key_value_heads)
247
+ else:
248
+ self.head_dim = getattr(config, "head_dim", default_head_dim)
249
+ self.v_head_dim = getattr(config, "v_head_dim", self.head_dim)
250
+ self.num_attention_heads = config.num_attention_heads
251
+ self.num_key_value_heads = config.num_key_value_heads
252
+
253
+ self.rope_dim = int(self.head_dim * getattr(config, "partial_rotary_factor", 1.0))
254
+ if self.rope_dim % 2 != 0:
255
+ raise ValueError(
256
+ f"MiMoV2 rotary dimension must be even, got {self.rope_dim} from "
257
+ f"head_dim={self.head_dim} and partial_rotary_factor={getattr(config, 'partial_rotary_factor', 1.0)}"
258
+ )
259
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
260
+ self.attention_dropout = getattr(config, "attention_dropout", 0.0)
261
+ self.scaling = self.head_dim**-0.5
262
+ self.sliding_window = getattr(config, "sliding_window", None) if is_swa else None
263
+ self.q_size = self.num_attention_heads * self.head_dim
264
+ self.k_size = self.num_key_value_heads * self.head_dim
265
+ self.v_size = self.num_key_value_heads * self.v_head_dim
266
+ self.o_hidden_size = self.num_attention_heads * self.v_head_dim
267
+ self.v_scale = getattr(config, "attention_value_scale", None)
268
+ self.attention_sink_bias = (
269
+ nn.Parameter(torch.empty(self.num_attention_heads), requires_grad=False)
270
+ if (
271
+ (getattr(config, "add_full_attention_sink_bias", False) and not is_swa)
272
+ or (getattr(config, "add_swa_attention_sink_bias", False) and is_swa)
273
+ )
274
+ else None
275
+ )
276
+
277
+ attention_bias = getattr(config, "attention_bias", False)
278
+ if self.projection_layout == "fused_qkv":
279
+ self.qkv_proj = nn.Linear(
280
+ config.hidden_size,
281
+ self.q_size + self.k_size + self.v_size,
282
+ bias=attention_bias,
283
+ )
284
+ else:
285
+ self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=attention_bias)
286
+ self.k_proj = nn.Linear(config.hidden_size, self.k_size, bias=attention_bias)
287
+ self.v_proj = nn.Linear(config.hidden_size, self.v_size, bias=attention_bias)
288
+ self.o_proj = nn.Linear(self.o_hidden_size, config.hidden_size, bias=False)
289
+
290
+ def _forward_attention(
291
+ self,
292
+ query_states: torch.Tensor,
293
+ key_states: torch.Tensor,
294
+ value_states: torch.Tensor,
295
+ input_shape: torch.Size,
296
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
297
+ attention_mask: Optional[torch.Tensor],
298
+ past_key_values: Optional[Cache] = None,
299
+ cache_position: Optional[torch.LongTensor] = None,
300
+ position_ids: Optional[torch.LongTensor] = None,
301
+ ) -> tuple[torch.Tensor, torch.Tensor]:
302
+ if self.v_scale is not None:
303
+ value_states = value_states * self.v_scale
304
+
305
+ cos, sin = position_embeddings
306
+ query_rope, query_nope = query_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
307
+ key_rope, key_nope = key_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
308
+ query_rope, key_rope = apply_rotary_pos_emb(query_rope, key_rope, cos, sin)
309
+ query_states = torch.cat([query_rope, query_nope], dim=-1)
310
+ key_states = torch.cat([key_rope, key_nope], dim=-1)
311
+
312
+ if past_key_values is not None:
313
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
314
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
315
+
316
+ attn_implementation = self.config._attn_implementation
317
+ if attn_implementation is not None and attn_implementation.startswith("paged|"):
318
+ raise ValueError(
319
+ "MiMoV2 remote code does not support paged attention cache. "
320
+ "Please use eager, sdpa, flex_attention, or flash_attention_2."
321
+ )
322
+
323
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
324
+ attn_implementation, eager_attention_forward
325
+ )
326
+ if self.attention_sink_bias is not None and attn_implementation == "sdpa":
327
+ logger.warning_once(
328
+ "MiMoV2 attention sink bias is not supported by SDPA; falling back to eager attention for correctness."
329
+ )
330
+ attention_interface = eager_attention_forward
331
+
332
+ attention_kwargs = {
333
+ "dropout": 0.0 if not self.training else self.attention_dropout,
334
+ "scaling": self.scaling,
335
+ "position_ids": position_ids,
336
+ "is_causal": self.is_causal,
337
+ }
338
+ if attention_interface is eager_attention_forward:
339
+ attention_kwargs["sinks"] = self.attention_sink_bias
340
+ else:
341
+ if self.attention_sink_bias is not None:
342
+ attention_kwargs["s_aux"] = self.attention_sink_bias
343
+ if self.sliding_window is not None:
344
+ attention_kwargs["sliding_window"] = self.sliding_window
345
+
346
+ attn_output, attn_weights = attention_interface(
347
+ self,
348
+ query_states,
349
+ key_states,
350
+ value_states,
351
+ attention_mask,
352
+ **attention_kwargs,
353
+ )
354
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
355
+ attn_output = self.o_proj(attn_output)
356
+ return attn_output, attn_weights
357
+
358
+ def forward(
359
+ self,
360
+ hidden_states: torch.Tensor,
361
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
362
+ attention_mask: Optional[torch.Tensor],
363
+ past_key_values: Optional[Cache] = None,
364
+ cache_position: Optional[torch.LongTensor] = None,
365
+ position_ids: Optional[torch.LongTensor] = None,
366
+ **kwargs: Unpack[TransformersKwargs],
367
+ ) -> tuple[torch.Tensor, torch.Tensor]:
368
+ input_shape = hidden_states.shape[:-1]
369
+
370
+ if self.projection_layout == "fused_qkv":
371
+ qkv_states = self.qkv_proj(hidden_states)
372
+ query_states, key_states, value_states = qkv_states.split([self.q_size, self.k_size, self.v_size], dim=-1)
373
+ else:
374
+ query_states = self.q_proj(hidden_states)
375
+ key_states = self.k_proj(hidden_states)
376
+ value_states = self.v_proj(hidden_states)
377
+
378
+ query_states = query_states.view(*input_shape, self.num_attention_heads, self.head_dim).transpose(1, 2)
379
+ key_states = key_states.view(*input_shape, self.num_key_value_heads, self.head_dim).transpose(1, 2)
380
+ value_states = value_states.view(*input_shape, self.num_key_value_heads, self.v_head_dim).transpose(1, 2)
381
+ return self._forward_attention(
382
+ query_states,
383
+ key_states,
384
+ value_states,
385
+ input_shape,
386
+ position_embeddings,
387
+ attention_mask,
388
+ past_key_values=past_key_values,
389
+ cache_position=cache_position,
390
+ position_ids=position_ids,
391
+ )
392
+
393
+
394
+ class MiMoV2DecoderLayer(nn.Module):
395
+ attention_projection_layout = "split"
396
+
397
+ def __init__(self, config, layer_idx: int, attention_projection_layout: Optional[str] = None):
398
+ super().__init__()
399
+ attention_projection_layout = attention_projection_layout or self.attention_projection_layout
400
+ is_swa_layer = config.hybrid_layer_pattern[layer_idx] == 1
401
+ self.attention_type = "sliding_window_attention" if is_swa_layer else "full_attention"
402
+ self.self_attn = MiMoV2Attention(
403
+ config, is_swa_layer, layer_idx, projection_layout=attention_projection_layout
404
+ )
405
+ self.mlp = (
406
+ MiMoV2MoE(config)
407
+ if getattr(config, "n_routed_experts", None) is not None and config.moe_layer_freq[layer_idx]
408
+ else MiMoV2MLP(config)
409
+ )
410
+ self.input_layernorm = MiMoV2RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
411
+ self.post_attention_layernorm = MiMoV2RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
412
+
413
+ def forward(
414
+ self,
415
+ hidden_states: torch.Tensor,
416
+ attention_mask: Optional[torch.Tensor] = None,
417
+ position_ids: Optional[torch.LongTensor] = None,
418
+ past_key_values: Optional[Cache] = None,
419
+ use_cache: Optional[bool] = False,
420
+ cache_position: Optional[torch.LongTensor] = None,
421
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
422
+ **kwargs: Unpack[TransformersKwargs],
423
+ ) -> torch.Tensor:
424
+ residual = hidden_states
425
+ hidden_states = self.input_layernorm(hidden_states)
426
+ hidden_states, _ = self.self_attn(
427
+ hidden_states=hidden_states,
428
+ attention_mask=attention_mask,
429
+ position_ids=position_ids,
430
+ past_key_values=past_key_values,
431
+ use_cache=use_cache,
432
+ cache_position=cache_position,
433
+ position_embeddings=position_embeddings,
434
+ **kwargs,
435
+ )
436
+ hidden_states = residual + hidden_states
437
+
438
+ residual = hidden_states
439
+ hidden_states = self.post_attention_layernorm(hidden_states)
440
+ hidden_states = self.mlp(hidden_states)
441
+ hidden_states = residual + hidden_states
442
+ return hidden_states
443
+
444
+
445
+ class MiMoV2RotaryEmbedding(nn.Module):
446
+ inv_freq: torch.Tensor
447
+
448
+ def __init__(self, config, is_swa: bool, device=None):
449
+ super().__init__()
450
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
451
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default"))
452
+ else:
453
+ self.rope_type = "default"
454
+ self.max_seq_len_cached = config.max_position_embeddings
455
+ self.original_max_seq_len = config.max_position_embeddings
456
+
457
+ self.config = copy(config)
458
+ self.config.rope_parameters = copy(getattr(config, "rope_parameters", None) or {})
459
+ if is_swa:
460
+ self.config.rope_theta = getattr(config, "swa_rope_theta", config.rope_theta)
461
+ self.config.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", None))
462
+ if self.config.rope_parameters:
463
+ self.config.rope_parameters["rope_theta"] = self.config.rope_theta
464
+ self.rope_init_fn = (
465
+ self.compute_default_rope_parameters
466
+ if self.rope_type == "default"
467
+ else ROPE_INIT_FUNCTIONS[self.rope_type]
468
+ )
469
+
470
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
471
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
472
+ self.original_inv_freq = self.inv_freq
473
+
474
+ @staticmethod
475
+ def compute_default_rope_parameters(config, device=None, seq_len=None, layer_type=None):
476
+ config.standardize_rope_params()
477
+ rope_parameters = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
478
+ base = rope_parameters["rope_theta"]
479
+ partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
480
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
481
+ dim = int(head_dim * partial_rotary_factor)
482
+ if dim % 2 != 0:
483
+ raise ValueError(
484
+ f"MiMoV2 rotary dimension must be even, got {dim} from "
485
+ f"head_dim={head_dim} and partial_rotary_factor={partial_rotary_factor}"
486
+ )
487
+ inv_freq = 1.0 / (
488
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
489
+ )
490
+ return inv_freq, 1.0
491
+
492
+ @torch.no_grad()
493
+ @dynamic_rope_update
494
+ def forward(self, x, position_ids):
495
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
496
+ position_ids_expanded = position_ids[:, None, :].float()
497
+
498
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
499
+ with torch.autocast(device_type=device_type, enabled=False):
500
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
501
+ emb = torch.cat((freqs, freqs), dim=-1)
502
+ cos = emb.cos() * self.attention_scaling
503
+ sin = emb.sin() * self.attention_scaling
504
+
505
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
506
+
507
+
508
+ # ---------------------------------------------------------------------------
509
+ # Multimodal helpers
510
+ # ---------------------------------------------------------------------------
511
+
512
+
513
+ def _as_namespace(config_like):
514
+ if config_like is None:
515
+ return SimpleNamespace()
516
+ if isinstance(config_like, dict):
517
+ return SimpleNamespace(**config_like)
518
+ return config_like
519
+
520
+
521
+ def _parse_maybe_list(value: str | int, length: int) -> list[int]:
522
+ if isinstance(value, str) and "-" in value:
523
+ return [int(x) for x in value.split("-")]
524
+ return [int(value)] * length
525
+
526
+
527
+ def _build_speech_embeddings(config) -> nn.ModuleList:
528
+ audio_channels = getattr(config, "audio_channels")
529
+ input_local_dim = getattr(config, "input_local_dim")
530
+ speech_empty_ids = _parse_maybe_list(getattr(config, "speech_zeroemb_idx"), audio_channels)
531
+ speech_vocab_sizes = _parse_maybe_list(getattr(config, "speech_vocab_size"), audio_channels)
532
+ return nn.ModuleList(
533
+ [
534
+ nn.Embedding(speech_vocab_sizes[i], input_local_dim, padding_idx=speech_empty_ids[i])
535
+ for i in range(audio_channels)
536
+ ]
537
+ )
538
+
539
+
540
+ def _pad_and_group_audio_codes(
541
+ audio_codes: torch.Tensor, audio_channels: int, group_size: int
542
+ ) -> torch.Tensor:
543
+ """Slice to `audio_channels`, pad to `group_size` boundary, reshape to [G, group_size, C]."""
544
+ if audio_codes.dim() != 2:
545
+ raise ValueError(f"`audio_codes` must be 2D [T, C], got shape={tuple(audio_codes.shape)}")
546
+ audio_codes = audio_codes[:, :audio_channels]
547
+ T = audio_codes.shape[0]
548
+ padded_T = ((T + group_size - 1) // group_size) * group_size
549
+ if padded_T > T:
550
+ audio_codes = torch.cat([audio_codes, audio_codes[-1:].expand(padded_T - T, -1)], dim=0)
551
+ return audio_codes.reshape(padded_T // group_size, group_size, audio_channels)
552
+
553
+
554
+ def _replace_modal_embeddings_inplace(
555
+ input_ids: torch.Tensor,
556
+ inputs_embeds: torch.Tensor,
557
+ token_id: int | None,
558
+ modal_embeds: torch.Tensor | None,
559
+ ) -> None:
560
+ if token_id is None or modal_embeds is None:
561
+ return
562
+
563
+ if modal_embeds.dim() != 2:
564
+ raise ValueError(f"`modal_embeds` must be 2D [N, H], got shape={tuple(modal_embeds.shape)}")
565
+
566
+ mask = input_ids.eq(token_id)
567
+ num_slots = int(mask.sum().item())
568
+ if num_slots == 0:
569
+ return
570
+
571
+ if modal_embeds.shape[0] != num_slots:
572
+ raise ValueError(
573
+ f"Modal embedding count mismatch for token_id={token_id}: "
574
+ f"found {num_slots} placeholders but got {modal_embeds.shape[0]} embeddings."
575
+ )
576
+
577
+ inputs_embeds[mask] = modal_embeds.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
578
+
579
+
580
+ # ---------------------------------------------------------------------------
581
+ # Vision encoder
582
+ # ---------------------------------------------------------------------------
583
+
584
+
585
+ def _rotate_half_vision(x: torch.Tensor) -> torch.Tensor:
586
+ x1 = x[..., : x.shape[-1] // 2]
587
+ x2 = x[..., x.shape[-1] // 2 :]
588
+ return torch.cat((-x2, x1), dim=-1)
589
+
590
+
591
+ def _apply_rotary_pos_emb_vision(
592
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
593
+ ) -> tuple[torch.Tensor, torch.Tensor]:
594
+ orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
595
+ q, k = q.float(), k.float()
596
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
597
+ q_embed = (q * cos) + (_rotate_half_vision(q) * sin)
598
+ k_embed = (k * cos) + (_rotate_half_vision(k) * sin)
599
+ return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype)
600
+
601
+
602
+ class MiMoVisionRotaryEmbedding(nn.Module):
603
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
604
+ super().__init__()
605
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
606
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
607
+
608
+ def forward(self, seqlen: int) -> torch.Tensor:
609
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
610
+ return torch.outer(seq, self.inv_freq)
611
+
612
+
613
+ class MiMoVisionPatchEmbed(nn.Module):
614
+ def __init__(
615
+ self, patch_size: int = 16, temporal_patch_size: int = 2, in_channels: int = 3, embed_dim: int = 1280
616
+ ):
617
+ super().__init__()
618
+ self.patch_size = patch_size
619
+ self.temporal_patch_size = temporal_patch_size
620
+ self.in_channels = in_channels
621
+ self.embed_dim = embed_dim
622
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
623
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
624
+
625
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
626
+ target_dtype = self.proj.weight.dtype
627
+ hidden_states = hidden_states.view(
628
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
629
+ )
630
+ return self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
631
+
632
+
633
+ class MiMoVisionSwiGLUMLP(nn.Module):
634
+ def __init__(self, dim: int, intermediate_dim: int, hidden_act: str = "silu"):
635
+ super().__init__()
636
+ self.gate_proj = nn.Linear(dim, intermediate_dim, bias=True)
637
+ self.up_proj = nn.Linear(dim, intermediate_dim, bias=True)
638
+ self.down_proj = nn.Linear(intermediate_dim, dim, bias=True)
639
+ self.act_fn = ACT2FN[hidden_act]
640
+
641
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
642
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
643
+
644
+
645
+ class MiMoVisionAttention(nn.Module):
646
+ def __init__(
647
+ self,
648
+ dim: int,
649
+ num_heads: int,
650
+ num_kv_heads: int | None = None,
651
+ head_dim: int | None = None,
652
+ use_sinks: bool = False,
653
+ window_size: int = -1,
654
+ ):
655
+ super().__init__()
656
+ self.dim = dim
657
+ self.num_heads = num_heads
658
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
659
+ self.head_dim = head_dim if head_dim is not None else dim // num_heads
660
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
661
+ self.scaling = self.head_dim**-0.5
662
+ self.window_size = window_size
663
+
664
+ qkv_dim = (self.num_heads + 2 * self.num_kv_heads) * self.head_dim
665
+ self.qkv = nn.Linear(dim, qkv_dim, bias=True)
666
+ self.proj = nn.Linear(self.num_heads * self.head_dim, dim, bias=True)
667
+ self.sinks = nn.Parameter(torch.zeros(self.num_heads)) if use_sinks else None
668
+
669
+ def _build_window_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
670
+ if self.window_size <= 0:
671
+ return None
672
+ row_idx = torch.arange(seq_len, device=device).unsqueeze(1)
673
+ col_idx = torch.arange(seq_len, device=device).unsqueeze(0)
674
+ mask = torch.zeros(seq_len, seq_len, device=device, dtype=dtype)
675
+ mask = mask.masked_fill((row_idx - col_idx).abs() > self.window_size, float("-inf"))
676
+ return mask
677
+
678
+ def forward(
679
+ self,
680
+ hidden_states: torch.Tensor,
681
+ cu_seqlens: torch.Tensor,
682
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
683
+ full_attn: bool = False,
684
+ ) -> torch.Tensor:
685
+ seq_len = hidden_states.shape[0]
686
+ qkv = self.qkv(hidden_states)
687
+
688
+ q_dim = self.num_heads * self.head_dim
689
+ kv_dim = self.num_kv_heads * self.head_dim
690
+ q = qkv[:, :q_dim].view(seq_len, self.num_heads, self.head_dim)
691
+ k = qkv[:, q_dim : q_dim + kv_dim].view(seq_len, self.num_kv_heads, self.head_dim)
692
+ v = qkv[:, q_dim + kv_dim :].view(seq_len, self.num_kv_heads, self.head_dim)
693
+
694
+ cos, sin = position_embeddings
695
+ q, k = _apply_rotary_pos_emb_vision(q, k, cos, sin)
696
+
697
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
698
+ q_chunks = torch.split(q, lengths.tolist(), dim=0)
699
+ k_chunks = torch.split(k, lengths.tolist(), dim=0)
700
+ v_chunks = torch.split(v, lengths.tolist(), dim=0)
701
+
702
+ outputs = []
703
+ for q_c, k_c, v_c in zip(q_chunks, k_chunks, v_chunks):
704
+ q_c = q_c.unsqueeze(0).transpose(1, 2)
705
+ k_c = k_c.unsqueeze(0).transpose(1, 2)
706
+ v_c = v_c.unsqueeze(0).transpose(1, 2)
707
+
708
+ if self.num_kv_groups > 1:
709
+ k_c = k_c.repeat_interleave(self.num_kv_groups, dim=1)
710
+ v_c = v_c.repeat_interleave(self.num_kv_groups, dim=1)
711
+
712
+ attn_mask = None
713
+ if not full_attn:
714
+ attn_mask = self._build_window_mask(q_c.shape[2], q_c.device, q_c.dtype)
715
+
716
+ if self.sinks is not None:
717
+ sink_bias = torch.zeros(
718
+ 1, self.num_heads, q_c.shape[2], k_c.shape[2], device=q_c.device, dtype=q_c.dtype
719
+ )
720
+ sink_bias[..., 0] = self.sinks.view(1, self.num_heads, 1)
721
+ attn_mask = sink_bias if attn_mask is None else attn_mask + sink_bias
722
+
723
+ attn_out = F.scaled_dot_product_attention(q_c, k_c, v_c, attn_mask=attn_mask, scale=self.scaling)
724
+ outputs.append(attn_out.squeeze(0).transpose(0, 1))
725
+
726
+ attn_output = torch.cat(outputs, dim=0)
727
+ attn_output = attn_output.reshape(seq_len, -1)
728
+ return self.proj(attn_output)
729
+
730
+
731
+ class MiMoVisionBlock(nn.Module):
732
+ def __init__(
733
+ self,
734
+ dim: int,
735
+ intermediate_dim: int,
736
+ num_heads: int,
737
+ num_kv_heads: int | None = None,
738
+ head_dim: int | None = None,
739
+ hidden_act: str = "silu",
740
+ rms_norm_eps: float = 1e-6,
741
+ use_sinks: bool = False,
742
+ window_size: int = -1,
743
+ ):
744
+ super().__init__()
745
+ self.norm1 = nn.RMSNorm(dim, eps=rms_norm_eps)
746
+ self.norm2 = nn.RMSNorm(dim, eps=rms_norm_eps)
747
+ self.attn = MiMoVisionAttention(
748
+ dim=dim, num_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_dim,
749
+ use_sinks=use_sinks, window_size=window_size,
750
+ )
751
+ self.mlp = MiMoVisionSwiGLUMLP(dim=dim, intermediate_dim=intermediate_dim, hidden_act=hidden_act)
752
+
753
+ def forward(
754
+ self,
755
+ hidden_states: torch.Tensor,
756
+ cu_seqlens: torch.Tensor,
757
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
758
+ full_attn: bool = False,
759
+ ) -> torch.Tensor:
760
+ hidden_states = hidden_states + self.attn(
761
+ self.norm1(hidden_states), cu_seqlens=cu_seqlens,
762
+ position_embeddings=position_embeddings, full_attn=full_attn,
763
+ )
764
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
765
+ return hidden_states
766
+
767
+
768
+ class MiMoVisionPatchMerger(nn.Module):
769
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2):
770
+ super().__init__()
771
+ self.hidden_size = context_dim * (spatial_merge_size**2)
772
+ self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
773
+ self.mlp = nn.Sequential(
774
+ nn.Linear(self.hidden_size, self.hidden_size),
775
+ nn.GELU(),
776
+ nn.Linear(self.hidden_size, dim),
777
+ )
778
+
779
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
780
+ return self.mlp(self.ln_q(x).view(-1, self.hidden_size))
781
+
782
+
783
+ class MiMoVisionTransformer(nn.Module):
784
+ def __init__(self, config):
785
+ super().__init__()
786
+ self.config = config
787
+ hidden_size = config.hidden_size
788
+ depth = config.depth
789
+ num_heads = config.num_heads
790
+ num_kv_heads = getattr(config, "num_key_value_heads", num_heads)
791
+ head_dim = getattr(config, "qk_channels", 64)
792
+ spatial_merge_size = getattr(config, "spatial_merge_size", 2)
793
+ rms_norm_eps = getattr(config, "rms_norm_eps", 1e-6)
794
+ self.fullatt_block_indexes = getattr(config, "fullatt_block_indexes", [])
795
+ use_sink = getattr(config, "use_sink", False)
796
+ visual_token_window_size = getattr(config, "visual_token_window_size", -1)
797
+ self.vit_window_attn_types = getattr(config, "vit_window_attn_types", None) or [-1] * depth
798
+
799
+ self.spatial_merge_size = spatial_merge_size
800
+ self.spatial_merge_unit = spatial_merge_size * spatial_merge_size
801
+
802
+ self.patch_embed = MiMoVisionPatchEmbed(
803
+ patch_size=config.patch_size,
804
+ temporal_patch_size=config.temporal_patch_size,
805
+ in_channels=getattr(config, "in_channels", None) or getattr(config, "in_chans", 3),
806
+ embed_dim=hidden_size,
807
+ )
808
+
809
+ self.rotary_pos_emb = MiMoVisionRotaryEmbedding(head_dim // 2)
810
+
811
+ self.blocks = nn.ModuleList(
812
+ [
813
+ MiMoVisionBlock(
814
+ dim=hidden_size,
815
+ intermediate_dim=config.intermediate_size,
816
+ num_heads=num_heads,
817
+ num_kv_heads=num_kv_heads,
818
+ head_dim=head_dim,
819
+ hidden_act=config.hidden_act,
820
+ rms_norm_eps=rms_norm_eps,
821
+ use_sinks=use_sink and (i not in self.fullatt_block_indexes),
822
+ window_size=visual_token_window_size,
823
+ )
824
+ for i in range(depth)
825
+ ]
826
+ )
827
+
828
+ self.merger = MiMoVisionPatchMerger(
829
+ dim=config.out_hidden_size,
830
+ context_dim=hidden_size,
831
+ spatial_merge_size=spatial_merge_size,
832
+ )
833
+
834
+ @property
835
+ def dtype(self) -> torch.dtype:
836
+ return self.patch_embed.proj.weight.dtype
837
+
838
+ def apply_index(self, tensor: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
839
+ tensor = tensor.unflatten(0, (-1, self.spatial_merge_unit))
840
+ tensor = tensor[index]
841
+ return tensor.flatten(0, 1)
842
+
843
+ def get_window_index_1d(self, grid_thw: torch.Tensor, col: bool = True) -> torch.Tensor:
844
+ window_index = []
845
+ window_index_id = 0
846
+ for grid_t, grid_h, grid_w in grid_thw:
847
+ llm_grid_h = grid_h // self.spatial_merge_size
848
+ llm_grid_w = grid_w // self.spatial_merge_size
849
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
850
+ index_new = index.transpose(1, 2).reshape(-1) if col else index.reshape(-1)
851
+ window_index.append(index_new + window_index_id)
852
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
853
+ return torch.cat(window_index, dim=0)
854
+
855
+ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
856
+ pos_ids = []
857
+ for t, h, w in grid_thw:
858
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
859
+ hpos_ids = hpos_ids.reshape(
860
+ h // self.spatial_merge_size, self.spatial_merge_size,
861
+ w // self.spatial_merge_size, self.spatial_merge_size,
862
+ )
863
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()
864
+
865
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
866
+ wpos_ids = wpos_ids.reshape(
867
+ h // self.spatial_merge_size, self.spatial_merge_size,
868
+ w // self.spatial_merge_size, self.spatial_merge_size,
869
+ )
870
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()
871
+
872
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
873
+ pos_ids = torch.cat(pos_ids, dim=0)
874
+ max_grid_size = grid_thw[:, 1:].max()
875
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
876
+ return rotary_pos_emb_full[pos_ids].flatten(1)
877
+
878
+ def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
879
+ x = pixel_values.to(device=self.patch_embed.proj.weight.device, dtype=self.dtype)
880
+ x = self.patch_embed(x)
881
+
882
+ rotary_emb = self.rot_pos_emb(grid_thw)
883
+ rotary_emb = rotary_emb.to(device=x.device)
884
+ emb = torch.cat((rotary_emb, rotary_emb), dim=-1)
885
+
886
+ window_index_1d_col = self.get_window_index_1d(grid_thw, col=True).to(device=x.device)
887
+ reverse_window_index_1d_col = torch.argsort(window_index_1d_col).to(device=x.device)
888
+
889
+ row_based_embeddings = (emb.cos(), emb.sin())
890
+ col_emb = self.apply_index(emb, window_index_1d_col)
891
+ col_based_embeddings = (col_emb.cos(), col_emb.sin())
892
+
893
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
894
+ dim=0, dtype=torch.int32
895
+ )
896
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(device=x.device)
897
+
898
+ for i, blk in enumerate(self.blocks):
899
+ window_attn_type = self.vit_window_attn_types[i]
900
+
901
+ if window_attn_type == 1 and (i == 0 or self.vit_window_attn_types[i - 1] != 1):
902
+ x = self.apply_index(x, window_index_1d_col)
903
+
904
+ if i > 0 and window_attn_type != 1 and self.vit_window_attn_types[i - 1] == 1:
905
+ x = self.apply_index(x, reverse_window_index_1d_col)
906
+
907
+ position_embeddings = col_based_embeddings if window_attn_type == 1 else row_based_embeddings
908
+ full_attn = i in self.fullatt_block_indexes
909
+ x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, full_attn=full_attn)
910
+
911
+ return self.merger(x)
912
+
913
+
914
+ # ---------------------------------------------------------------------------
915
+ # Audio encoder
916
+ # ---------------------------------------------------------------------------
917
+
918
+
919
+ class AudioProjection(nn.Module):
920
+ def __init__(self, input_size: int, hidden_size: int, output_size: int):
921
+ super().__init__()
922
+ self.mlp = nn.Sequential(
923
+ nn.Linear(input_size, hidden_size, bias=False),
924
+ nn.GELU(),
925
+ nn.Linear(hidden_size, output_size, bias=False),
926
+ )
927
+
928
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
929
+ return self.mlp(x)
930
+
931
+
932
+ class MiMoAudioEncoder(nn.Module):
933
+ def __init__(self, config):
934
+ super().__init__()
935
+ self.config = config
936
+
937
+ self.audio_channels = getattr(config, "audio_channels")
938
+ self.group_size = getattr(config, "group_size")
939
+ self.input_local_dim = getattr(config, "input_local_dim")
940
+ self.out_hidden_size = getattr(config, "out_hidden_size")
941
+ self.input_full_attention = getattr(config, "input_full_attention", True)
942
+ self.audio_segment_size = getattr(config, "audio_segment_size", 6000)
943
+
944
+ input_local_config = Qwen2Config(
945
+ hidden_size=getattr(config, "input_local_dim"),
946
+ num_hidden_layers=getattr(config, "input_local_layers"),
947
+ num_attention_heads=getattr(config, "input_local_attn_heads"),
948
+ num_key_value_heads=getattr(config, "input_local_attn_heads"),
949
+ intermediate_size=getattr(config, "input_local_intermediate_size"),
950
+ attention_dropout=getattr(config, "input_local_hidden_dropout", 0.0),
951
+ rope_theta=getattr(config, "rope_theta", 640000.0),
952
+ partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0),
953
+ )
954
+ self.input_local_transformer = Qwen2Model(input_local_config)
955
+
956
+ if not getattr(config, "add_post_norm", True):
957
+ self.input_local_transformer.norm = nn.Identity()
958
+
959
+ proj_in = self.input_local_dim * self.group_size
960
+ projection_layers = getattr(config, "projection_layers", 2)
961
+ if projection_layers == 1:
962
+ self.projection = nn.Linear(proj_in, self.out_hidden_size, bias=False)
963
+ elif projection_layers == 2:
964
+ self.projection = AudioProjection(proj_in, proj_in * 4, self.out_hidden_size)
965
+ else:
966
+ raise ValueError(f"Unsupported projection_layers={projection_layers}, expected 1 or 2.")
967
+
968
+ def _apply_speech_embeddings(self, audio_codes: torch.Tensor, speech_embeddings: nn.ModuleList) -> torch.Tensor:
969
+ num_segments = audio_codes.shape[0]
970
+ out = torch.zeros(
971
+ (num_segments, self.group_size, self.input_local_dim),
972
+ dtype=speech_embeddings[0].weight.dtype,
973
+ device=audio_codes.device,
974
+ )
975
+ for i in range(self.audio_channels):
976
+ out.add_(speech_embeddings[i](audio_codes[:, :, i].long()))
977
+ return out
978
+
979
+ def _apply_input_local_transformer(self, speech_embeddings: torch.Tensor) -> torch.Tensor:
980
+ output = self.input_local_transformer(
981
+ inputs_embeds=speech_embeddings, return_dict=True, use_cache=False,
982
+ is_causal=not self.input_full_attention,
983
+ )
984
+ return output.last_hidden_state
985
+
986
+ def _process_audio_codes(self, audio_codes: torch.Tensor, speech_embeddings: nn.ModuleList) -> torch.Tensor:
987
+ audio_codes = _pad_and_group_audio_codes(audio_codes, self.audio_channels, self.group_size)
988
+ audio_embs = self._apply_speech_embeddings(audio_codes, speech_embeddings)
989
+ audio_hidden = self._apply_input_local_transformer(audio_embs)
990
+ return self.projection(audio_hidden.reshape(audio_hidden.shape[0], -1))
991
+
992
+ def get_audio_feature(
993
+ self,
994
+ mels: list[torch.Tensor],
995
+ speech_embeddings: nn.ModuleList,
996
+ audio_tokenizer_encoder,
997
+ ) -> torch.Tensor:
998
+ """Full pipeline: mel spectrograms → tokenize → codes → embed → project."""
999
+ if not mels:
1000
+ device = next(self.projection.parameters()).device
1001
+ dtype = next(self.projection.parameters()).dtype
1002
+ return torch.empty(0, self.out_hidden_size, device=device, dtype=dtype)
1003
+
1004
+ device = next(audio_tokenizer_encoder.parameters()).device
1005
+ code_list = tokenize_audio_batch(
1006
+ mels, audio_tokenizer_encoder, segment_size=self.audio_segment_size, device=device,
1007
+ )
1008
+
1009
+ codecs_to_concat = []
1010
+ for codecs in code_list:
1011
+ codecs_to_concat.append(_pad_and_group_audio_codes(codecs, self.audio_channels, self.group_size))
1012
+ audio_codes = torch.cat(codecs_to_concat, dim=0)
1013
+
1014
+ audio_embs = self._apply_speech_embeddings(audio_codes, speech_embeddings)
1015
+ audio_hidden = self._apply_input_local_transformer(audio_embs)
1016
+ return self.projection(audio_hidden.reshape(audio_hidden.shape[0], -1))
1017
+
1018
+ def forward(
1019
+ self,
1020
+ speech_embeddings: nn.ModuleList,
1021
+ audio_codes: torch.Tensor | None = None,
1022
+ audio_embeds: torch.Tensor | None = None,
1023
+ ) -> torch.Tensor:
1024
+ if audio_embeds is not None:
1025
+ if audio_embeds.dim() != 2:
1026
+ raise ValueError(f"`audio_embeds` must be 2D [N, H], got shape={tuple(audio_embeds.shape)}")
1027
+ if audio_embeds.shape[-1] != self.out_hidden_size:
1028
+ raise ValueError(
1029
+ f"Unexpected audio_embeds hidden size {audio_embeds.shape[-1]}, expected {self.out_hidden_size}"
1030
+ )
1031
+ return audio_embeds
1032
+
1033
+ if audio_codes is None:
1034
+ raise ValueError("Either `audio_codes` or `audio_embeds` must be provided.")
1035
+
1036
+ return self._process_audio_codes(audio_codes, speech_embeddings)
1037
+
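+ # Illustrative wiring sketch (comments only, not executed): the audio encoder embeds and sums
+ # the per-channel discrete codes, runs the local transformer over each group of `group_size`
+ # frames, and projects the flattened group to `out_hidden_size`. `_as_namespace` and
+ # `_build_speech_embeddings` are the helpers used by `MiMoV2ForCausalLM.__init__` further down;
+ # the inputs here are placeholders.
+ #
+ #   audio_cfg = _as_namespace(config.audio_config)
+ #   speech_embeddings = _build_speech_embeddings(audio_cfg)
+ #   audio_encoder = MiMoAudioEncoder(audio_cfg)
+ #   feats = audio_encoder(speech_embeddings=speech_embeddings, audio_codes=audio_codes)
+ #   # feats: [num_groups, out_hidden_size], spliced over the audio placeholder tokens later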
1038
+
1039
+ # ---------------------------------------------------------------------------
1040
+ # Audio tokenizer (codec: mel → encoder → VQ → codes)
1041
+ # Adapted from https://github.com/XiaomiMiMo/MiMo-Audio-Tokenizer.git
1042
+ # ---------------------------------------------------------------------------
1043
+
1044
+
1045
+ class MiMoAudioTokenizerConfig(PretrainedConfig):
1046
+ model_type = "mimo_audio_tokenizer"
1047
+
1048
+ def __init__(
1049
+ self,
1050
+ max_audio_seconds: int = 1800,
1051
+ stride_size: int = 2,
1052
+ avg_pooler: int = 1,
1053
+ d_model: int = 768,
1054
+ scale_embedding: bool = True,
1055
+ kernel_size: int = 3,
1056
+ activation_function: str = "gelu",
1057
+ encoder_layers: int = 8,
1058
+ encoder_skip_layer_id: Optional[int] = None,
1059
+ encoder_attention_heads: int = 12,
1060
+ encoder_ffn_dim: int = 3072,
1061
+ encoder_causal: bool = False,
1062
+ encoder_attn_window_size: Optional[list] = None,
1063
+ decoder_layers: int = 8,
1064
+ decoder_attention_heads: int = 12,
1065
+ decoder_ffn_dim: int = 3072,
1066
+ decoder_kernel_size: int = 3,
1067
+ decoder_stride_size: int = 2,
1068
+ decoder_causal: bool = True,
1069
+ decoder_attn_window_size: Optional[list] = None,
1070
+ nfft: int = 1024,
1071
+ vocoder_dim: int = 512,
1072
+ vocoder_intermediate_dim: int = 4096,
1073
+ vocoder_num_layers: int = 30,
1074
+ n_mels: int = 80,
1075
+ sampling_rate: int = 24000,
1076
+ hop_length: int = 240,
1077
+ window_size: int = 1024,
1078
+ vocoder_padding: str = "same",
1079
+ fmin: int = 0,
1080
+ fmax: Optional[int] = None,
1081
+ num_quantizers: int = 12,
1082
+ codebook_size: Optional[list] = None,
1083
+ threshold_ema_dead_code: int = 10,
1084
+ position_embedding_type: str = "rope",
1085
+ rope_theta: int = 10000,
1086
+ rope_type: str = "default",
1087
+ ln_type: str = "LayerNorm",
1088
+ vocoder_attention_heads: int = 4,
1089
+ vocoder_attn_window_size: Optional[list] = None,
1090
+ use_istft_only: bool = False,
1091
+ hybrid_attention: bool = False,
1092
+ hybrid_block_size: int = 8,
1093
+ swa_per_block: int = 2,
1094
+ **kwargs,
1095
+ ):
1096
+ super().__init__(**kwargs)
1097
+ self.max_audio_seconds = max_audio_seconds
1098
+ self.stride_size = stride_size
1099
+ self.avg_pooler = avg_pooler
1100
+ self.d_model = d_model
1101
+ self.scale_embedding = scale_embedding
1102
+ self.kernel_size = kernel_size
1103
+ self.activation_function = activation_function
1104
+ self.encoder_layers = encoder_layers
1105
+ self.encoder_skip_layer_id = encoder_skip_layer_id
1106
+ self.encoder_attention_heads = encoder_attention_heads
1107
+ self.encoder_ffn_dim = encoder_ffn_dim
1108
+ self.encoder_causal = encoder_causal
1109
+ self.encoder_attn_window_size = encoder_attn_window_size if encoder_attn_window_size is not None else [-1, -1]
1110
+ self.decoder_layers = decoder_layers
1111
+ self.decoder_attention_heads = decoder_attention_heads
1112
+ self.decoder_ffn_dim = decoder_ffn_dim
1113
+ self.decoder_kernel_size = decoder_kernel_size
1114
+ self.decoder_stride_size = decoder_stride_size
1115
+ self.decoder_causal = decoder_causal
1116
+ self.decoder_attn_window_size = decoder_attn_window_size if decoder_attn_window_size is not None else [-1, -1]
1117
+ self.nfft = nfft
1118
+ self.vocoder_dim = vocoder_dim
1119
+ self.vocoder_intermediate_dim = vocoder_intermediate_dim
1120
+ self.vocoder_num_layers = vocoder_num_layers
1121
+ self.n_mels = n_mels
1122
+ self.sampling_rate = sampling_rate
1123
+ self.hop_length = hop_length
1124
+ self.window_size = window_size
1125
+ self.vocoder_padding = vocoder_padding
1126
+ self.fmin = fmin
1127
+ self.fmax = fmax
1128
+ self.num_quantizers = num_quantizers
1129
+ self.codebook_size = codebook_size if codebook_size is not None else [1024]
1130
+ self.threshold_ema_dead_code = threshold_ema_dead_code
1131
+ self.position_embedding_type = position_embedding_type
1132
+ self.rope_theta = rope_theta
1133
+ self.rope_type = rope_type
1134
+ self.ln_type = ln_type
1135
+ self.vocoder_attention_heads = vocoder_attention_heads
1136
+ self.vocoder_attn_window_size = vocoder_attn_window_size if vocoder_attn_window_size is not None else [40, 10]
1137
+ self.use_istft_only = use_istft_only
1138
+ self.hybrid_attention = hybrid_attention
1139
+ self.hybrid_block_size = hybrid_block_size
1140
+ self.swa_per_block = swa_per_block
1141
+
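+ # Quick reference (illustrative): with the defaults above, one codec frame covers
+ # hop_length * 2 * avg_pooler = 240 * 2 * 1 = 480 samples (`MiMoAudioTokenizer.downsample_rate`
+ # below), i.e. 24000 / 480 = 50 codec frames per second of audio. A minimal construction sketch:
+ #
+ #   cfg = MiMoAudioTokenizerConfig(num_quantizers=12, codebook_size=[1024])
+ #   tokenizer = MiMoAudioTokenizer(cfg)   # see the class further down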
1142
+
1143
+ class EuclideanCodebook(nn.Module):
1144
+ def __init__(self, dim: int, codebook_size: int, kmeans_init: bool = False, **kwargs):
1145
+ super().__init__()
1146
+ init_fn = torch.zeros if kmeans_init else self._uniform_init
1147
+ embed = init_fn(codebook_size, dim)
1148
+ self.codebook_size = codebook_size
1149
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
1150
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
1151
+ self.register_buffer("embed", embed)
1152
+ self.register_buffer("embed_avg", embed.clone())
1153
+
1154
+ def quantize(self, x):
1155
+ embed = self.embed.t()
1156
+ dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
1157
+ return dist.max(dim=-1).indices
1158
+
1159
+ def encode(self, x):
1160
+ shape = x.shape
1161
+ x = x.reshape(-1, x.shape[-1])
1162
+ embed_ind = self.quantize(x)
1163
+ return embed_ind.view(*shape[:-1])
1164
+
1165
+ def decode(self, embed_ind):
1166
+ return F.embedding(embed_ind, self.embed)
1167
+
1168
+ @staticmethod
1169
+ def _uniform_init(*shape: int):
1170
+ t = torch.empty(shape)
1171
+ nn.init.kaiming_uniform_(t)
1172
+ return t
1173
+
1174
+
1175
+ class VectorQuantization(nn.Module):
1176
+ def __init__(self, dim: int, codebook_size: int, codebook_dim: Optional[int] = None, kmeans_init: bool = True, **kwargs):
1177
+ super().__init__()
1178
+ _codebook_dim = codebook_dim if codebook_dim is not None else dim
1179
+ requires_projection = _codebook_dim != dim
1180
+ self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
1181
+ self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
1182
+ self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, kmeans_init=kmeans_init)
1183
+ self.codebook_size = codebook_size
1184
+
1185
+ def encode(self, x):
1186
+ return self._codebook.encode(self.project_in(x))
1187
+
1188
+ def decode(self, embed_ind):
1189
+ return self.project_out(self._codebook.decode(embed_ind))
1190
+
1191
+
1192
+ class ResidualVectorQuantization(nn.Module):
1193
+ def __init__(self, *, num_quantizers, codebook_size, **kwargs):
1194
+ super().__init__()
1195
+ if isinstance(codebook_size, int):
1196
+ codebook_size = [codebook_size] * num_quantizers
1197
+ elif len(codebook_size) < num_quantizers:
1198
+ codebook_size += [codebook_size[-1]] * (num_quantizers - len(codebook_size))
1199
+ self.layers = nn.ModuleList(
1200
+ [VectorQuantization(codebook_size=codebook_size[i], **kwargs) for i in range(num_quantizers)]
1201
+ )
1202
+
1203
+ def encode(self, x: torch.Tensor, n_q: Optional[int] = None, st: Optional[int] = None) -> torch.Tensor:
1204
+ residual = x
1205
+ all_indices = []
1206
+ n_q = len(self.layers) if n_q is None else n_q
1207
+ st = 0 if st is None else st
1208
+ for layer in self.layers[st:n_q]:
1209
+ indices = layer.encode(residual)
1210
+ quantized = layer.decode(indices)
1211
+ residual = residual - quantized
1212
+ all_indices.append(indices)
1213
+ return torch.stack(all_indices)
1214
+
1215
+ def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
1216
+ quantized_out = self.layers[st].decode(q_indices[0])
1217
+ for i in range(1, len(q_indices)):
1218
+ quantized_out = quantized_out + self.layers[st + i].decode(q_indices[i])
1219
+ return quantized_out
1220
+
1221
+
1222
+ class ResidualVectorQuantizer(nn.Module):
1223
+ def __init__(self, dimension: int = 256, n_q: int = 8, bins: int | list = 1024, kmeans_init: bool = True, **kwargs):
1224
+ super().__init__()
1225
+ self.n_q = n_q
1226
+ self.vq = ResidualVectorQuantization(dim=dimension, codebook_size=bins, num_quantizers=n_q, kmeans_init=kmeans_init)
1227
+
1228
+ def encode(self, x: torch.Tensor, n_q: Optional[int] = None, st: Optional[int] = None) -> torch.Tensor:
1229
+ return self.vq.encode(x, n_q=n_q or self.n_q, st=st or 0)
1230
+
1231
+ def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
1232
+ return self.vq.decode(codes, st=st)
1233
+
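+ # Shape sketch for the residual VQ round trip (illustrative values; in practice the codebooks
+ # are loaded from the checkpoint, so this is about shapes, not meaningful reconstructions):
+ #
+ #   rvq = ResidualVectorQuantizer(dimension=512, n_q=4, bins=1024)
+ #   codes = rvq.encode(torch.randn(10, 512))   # [4, 10]: one code index per quantizer per frame
+ #   recon = rvq.decode(codes)                  # [10, 512]: sum of the per-layer codebook vectors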
1234
+
1235
+ class AudioTokenizerRotaryEmbedding(nn.Module):
1236
+ def __init__(self, base, dim, max_seq_len, rope_type="default", device=None):
1237
+ super().__init__()
1238
+ self.attention_scaling = 1.0
1239
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))
1240
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
1241
+
1242
+ @torch.no_grad()
1243
+ def forward(self, x, position_ids):
1244
+ inv_freq_expanded = self.inv_freq[:, None].float().expand(-1, 1).to(x.device)
1245
+ position_ids_expanded = position_ids[None, :].float()
1246
+ with torch.autocast(device_type=x.device.type if x.device.type != "mps" else "cpu", enabled=False):
1247
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(0, 1)
1248
+ emb = torch.cat((freqs, freqs), dim=-1)
1249
+ cos = emb.cos() * self.attention_scaling
1250
+ sin = emb.sin() * self.attention_scaling
1251
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
1252
+
1253
+
1254
+ def _at_get_position_ids(lengths):
1255
+ total_len = lengths.sum()
1256
+ offset = torch.cat([torch.zeros(1, device=lengths.device, dtype=lengths.dtype), lengths[:-1].cumsum(dim=0)])
1257
+ offset = torch.repeat_interleave(offset, lengths)
1258
+ return torch.arange(0, total_len, device=lengths.device) - offset
1259
+
1260
+
1261
+ def _at_get_sequence_mask(inputs, inputs_length):
1262
+ if inputs.dim() == 3:
1263
+ bsz, tgt_len, _ = inputs.size()
1264
+ else:
1265
+ bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
1266
+ sequence_mask = torch.arange(0, tgt_len, device=inputs.device)
1267
+ sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
1268
+ unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1
1269
+ return sequence_mask, unpacking_index
1270
+
1271
+
1272
+ def _at_unpack_hidden_states(hidden_states, lengths, sequence_mask=None, unpacking_index=None):
1273
+ bsz = lengths.shape[0]
1274
+ if sequence_mask is None or unpacking_index is None:
1275
+ sequence_mask, unpacking_index = _at_get_sequence_mask(hidden_states, lengths)
1276
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
1277
+ bsz, torch.max(lengths), hidden_states.shape[-1]
1278
+ )
1279
+ return torch.where(sequence_mask, hidden_states, 0)
1280
+
1281
+
1282
+ def _at_rotate_half(x):
1283
+ x1 = x[..., : x.shape[-1] // 2]
1284
+ x2 = x[..., x.shape[-1] // 2 :]
1285
+ return torch.cat((-x2, x1), dim=-1)
1286
+
1287
+
1288
+ def _at_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
1289
+ cos = cos.unsqueeze(unsqueeze_dim)
1290
+ sin = sin.unsqueeze(unsqueeze_dim)
1291
+ return (q * cos) + (_at_rotate_half(q) * sin), (k * cos) + (_at_rotate_half(k) * sin)
1292
+
1293
+
1294
+ _AT_LAYER_NORM = {"LayerNorm": nn.LayerNorm}
1295
+
1296
+
1297
+ class AudioTokenizerAttention(nn.Module):
1298
+ def __init__(self, embed_dim: int, num_heads: int, window_size: tuple[int, int] = (-1, -1), causal: bool = False):
1299
+ super().__init__()
1300
+ self.embed_dim = embed_dim
1301
+ self.num_heads = num_heads
1302
+ self.head_dim = embed_dim // num_heads
1303
+ self.window_size = window_size
1304
+ self.causal = causal
1305
+ self.scaling = self.head_dim**-0.5
1306
+
1307
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
1308
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
1309
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
1310
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
1311
+
1312
+ def _build_attn_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
1313
+ has_window = self.window_size[0] > 0
1314
+ if not self.causal and not has_window:
1315
+ return None
1316
+ mask = torch.zeros(seq_len, seq_len, device=device, dtype=dtype)
1317
+ if self.causal:
1318
+ mask = mask + torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=dtype), diagonal=1)
1319
+ if has_window:
1320
+ row_idx = torch.arange(seq_len, device=device).unsqueeze(1)
1321
+ col_idx = torch.arange(seq_len, device=device).unsqueeze(0)
1322
+ mask = mask.masked_fill((row_idx - col_idx).abs() > self.window_size[0], float("-inf"))
1323
+ return mask
1324
+
1325
+ def forward(self, hidden_states, cu_seqlens, max_seqlen, rope_position_embeddings=None):
1326
+ total_len = hidden_states.shape[0]
1327
+ q = self.q_proj(hidden_states).view(total_len, self.num_heads, self.head_dim)
1328
+ k = self.k_proj(hidden_states).view(total_len, self.num_heads, self.head_dim)
1329
+ v = self.v_proj(hidden_states).view(total_len, self.num_heads, self.head_dim)
1330
+ if rope_position_embeddings is not None:
1331
+ cos, sin = rope_position_embeddings
1332
+ q, k = _at_apply_rotary_pos_emb(q, k, cos, sin)
1333
+ num_seqs = cu_seqlens.shape[0] - 1
1334
+ outputs = []
1335
+ for i in range(num_seqs):
1336
+ start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
1337
+ seq_len = end - start
1338
+ q_seq = q[start:end].transpose(0, 1).unsqueeze(0)
1339
+ k_seq = k[start:end].transpose(0, 1).unsqueeze(0)
1340
+ v_seq = v[start:end].transpose(0, 1).unsqueeze(0)
1341
+ attn_mask = self._build_attn_mask(seq_len, q_seq.device, q_seq.dtype)
1342
+ out = F.scaled_dot_product_attention(q_seq, k_seq, v_seq, attn_mask=attn_mask, scale=self.scaling)
1343
+ outputs.append(out.squeeze(0).transpose(0, 1))
1344
+ return self.out_proj(torch.cat(outputs, dim=0).reshape(total_len, self.embed_dim))
1345
+
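+ # Note: `_build_attn_mask` combines an optional causal triangle with an optional symmetric
+ # local window of radius `window_size[0]` (positions farther apart than the radius are masked
+ # with -inf). Attention is then computed per packed sequence: the loop over `cu_seqlens`
+ # slices each sequence out and calls scaled_dot_product_attention on it separately.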
1346
+
1347
+ class AudioTokenizerTransformerLayer(nn.Module):
1348
+ def __init__(self, config: MiMoAudioTokenizerConfig, causal: bool, attn_window_size: tuple[int, int] = (-1, -1)):
1349
+ super().__init__()
1350
+ self.embed_dim = config.d_model
1351
+ self.self_attn = AudioTokenizerAttention(
1352
+ embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads,
1353
+ window_size=attn_window_size, causal=causal,
1354
+ )
1355
+ self.self_attn_layer_norm = _AT_LAYER_NORM[config.ln_type](self.embed_dim)
1356
+ self.activation_fn = ACT2FN[config.activation_function]
1357
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
1358
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
1359
+ self.final_layer_norm = _AT_LAYER_NORM[config.ln_type](self.embed_dim)
1360
+
1361
+ def forward(self, hidden_states, cu_seqlens, max_seqlen, rope_position_embeddings):
1362
+ residual = hidden_states
1363
+ hidden_states = self.self_attn_layer_norm(hidden_states)
1364
+ hidden_states = self.self_attn(hidden_states, cu_seqlens, max_seqlen, rope_position_embeddings=rope_position_embeddings)
1365
+ hidden_states = residual + hidden_states
1366
+ residual = hidden_states
1367
+ hidden_states = self.final_layer_norm(hidden_states)
1368
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
1369
+ hidden_states = self.fc2(hidden_states)
1370
+ hidden_states = residual + hidden_states
1371
+ return hidden_states
1372
+
1373
+
1374
+ class AudioTokenizerEncoder(nn.Module):
1375
+ def __init__(self, config: MiMoAudioTokenizerConfig):
1376
+ super().__init__()
1377
+ self.config = config
1378
+ self.max_source_positions = (config.max_audio_seconds * config.sampling_rate // config.hop_length) // config.stride_size
1379
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
1380
+ self.skip_layer_idx = config.encoder_skip_layer_id
1381
+
1382
+ self.conv1 = nn.Conv1d(config.n_mels, config.d_model, kernel_size=config.kernel_size, padding=1)
1383
+ self.conv2 = nn.Conv1d(config.d_model, config.d_model, kernel_size=config.kernel_size, stride=config.stride_size, padding=1)
1384
+
1385
+ self.position_embedding = AudioTokenizerRotaryEmbedding(
1386
+ config.rope_theta, config.d_model // config.encoder_attention_heads,
1387
+ self.max_source_positions, config.rope_type,
1388
+ )
1389
+
1390
+ attn_window_sizes = []
1391
+ if config.hybrid_attention:
1392
+ for i in range(config.encoder_layers):
1393
+ if i % config.swa_per_block < config.swa_per_block - 1:
1394
+ attn_window_sizes.append(tuple(config.encoder_attn_window_size))
1395
+ else:
1396
+ attn_window_sizes.append((-1, -1))
1397
+ else:
1398
+ attn_window_sizes = [tuple(config.encoder_attn_window_size)] * config.encoder_layers
1399
+
1400
+ self.layers = nn.ModuleList([
1401
+ AudioTokenizerTransformerLayer(config=config, causal=config.encoder_causal, attn_window_size=attn_window_sizes[i])
1402
+ for i in range(config.encoder_layers)
1403
+ ])
1404
+
1405
+ self.layer_norm = _AT_LAYER_NORM[config.ln_type](config.d_model)
1406
+
1407
+ if config.avg_pooler != 1:
1408
+ self.down_sample_layer = nn.Sequential(
1409
+ nn.Conv1d(config.d_model, config.d_model, config.avg_pooler, config.avg_pooler, bias=False),
1410
+ nn.GELU(),
1411
+ )
1412
+ self.down_sample_norm = _AT_LAYER_NORM[config.ln_type](config.d_model)
1413
+ else:
1414
+ self.down_sample_layer = None
1415
+
1416
+ if config.num_quantizers != 0:
1417
+ self.quantizer = ResidualVectorQuantizer(
1418
+ dimension=config.d_model, n_q=config.num_quantizers,
1419
+ bins=config.codebook_size,
1420
+ threshold_ema_dead_code=config.threshold_ema_dead_code,
1421
+ )
1422
+ else:
1423
+ self.quantizer = None
1424
+
1425
+ def get_output_length(self, mel_len):
1426
+ tgt_len = mel_len + 3 - self.config.kernel_size
1427
+ return (tgt_len + 2 - self.config.kernel_size) // self.config.stride_size + 1
1428
+
1429
+ def get_features(self, input_features, output_length):
1430
+ input_features = input_features.to(self.conv1.weight)
1431
+ inputs_embeds = F.gelu(self.conv1(input_features))
1432
+ inputs_embeds = F.gelu(self.conv2(inputs_embeds))
1433
+ inputs_embeds = inputs_embeds.permute(0, 2, 1)
1434
+ bsz, tgt_len, _ = inputs_embeds.size()
1435
+
1436
+ position_ids = _at_get_position_ids(output_length).long().to(input_features.device)
1437
+ rope_position_embeddings = self.position_embedding(input_features, position_ids)
1438
+
1439
+ attention_mask, unpacking_index = _at_get_sequence_mask(inputs_embeds, output_length)
1440
+ hidden_states = torch.masked_select(inputs_embeds, attention_mask).view(
1441
+ torch.sum(output_length), self.config.d_model
1442
+ )
1443
+
1444
+ cu_seqlens = F.pad(torch.cumsum(output_length, dim=0), (1, 0), "constant", 0).to(
1445
+ device=hidden_states.device, dtype=torch.int32
1446
+ )
1447
+ max_seqlen = torch.max(output_length).to(torch.int32).item()
1448
+
1449
+ skip_connect_hidden_states = 0.0
1450
+ for idx, encoder_layer in enumerate(self.layers):
1451
+ hidden_states = encoder_layer(hidden_states, cu_seqlens, max_seqlen, rope_position_embeddings=rope_position_embeddings)
1452
+ if self.skip_layer_idx is not None and idx == self.skip_layer_idx - 1:
1453
+ skip_connect_hidden_states = hidden_states.clone()
1454
+
1455
+ hidden_states += skip_connect_hidden_states
1456
+ hidden_states = self.layer_norm(hidden_states)
1457
+
1458
+ if self.down_sample_layer is not None:
1459
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
1460
+ if hidden_states.size(1) % self.config.avg_pooler:
1461
+ pad_len = self.config.avg_pooler - hidden_states.size(1) % self.config.avg_pooler
1462
+ hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len), mode="constant", value=0.0)
1463
+ tgt_len += pad_len
1464
+ tgt_len = tgt_len // self.config.avg_pooler
1465
+ hidden_states = self.down_sample_layer(hidden_states.transpose(1, 2))
1466
+ output_length = output_length // self.config.avg_pooler + (output_length % self.config.avg_pooler != 0).int()
1467
+ hidden_states = hidden_states.transpose(1, 2)
1468
+ attention_mask, unpacking_index = _at_get_sequence_mask(hidden_states, output_length)
1469
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(
1470
+ torch.sum(output_length), self.config.d_model
1471
+ )
1472
+ hidden_states = self.down_sample_norm(hidden_states)
1473
+
1474
+ return hidden_states, output_length, attention_mask, unpacking_index, tgt_len, bsz
1475
+
1476
+ @torch.no_grad()
1477
+ def encode(self, input_features, input_lens=None, output_length=None, return_codes_only=False, n_q=None, use_quantizer=True):
1478
+ if output_length is None:
1479
+ output_length = self.get_output_length(input_lens)
1480
+ input_features = _at_unpack_hidden_states(input_features, input_lens)
1481
+ hidden_states, output_length, attention_mask, unpacking_index, tgt_len, bsz = self.get_features(
1482
+ input_features=input_features.transpose(1, 2), output_length=output_length,
1483
+ )
1484
+ dtype = hidden_states.dtype
1485
+ if use_quantizer and self.quantizer is not None:
1486
+ self.quantizer.float()
1487
+ codes = self.quantizer.encode(hidden_states.float(), n_q=n_q)
1488
+ if return_codes_only:
1489
+ return codes, output_length
1490
+ hidden_states = self.quantizer.decode(codes)
1491
+ hidden_states = hidden_states.to(dtype)
1492
+ else:
1493
+ codes = None
1494
+ hidden_states_packed = hidden_states.clone()
1495
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
1496
+ hidden_states = torch.where(attention_mask, hidden_states, 0)
1497
+ return hidden_states, hidden_states_packed, output_length, codes
1498
+
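+ # Illustrative call into the encoder with packed mel frames (shapes only; assumes
+ # `encoder = MiMoAudioTokenizer(cfg).encoder` with weights already loaded):
+ #
+ #   mel_frames = torch.randn(600, cfg.n_mels)          # two clips packed along dim 0
+ #   lens = torch.tensor([400, 200])
+ #   codes, code_lens = encoder.encode(mel_frames, input_lens=lens, return_codes_only=True)
+ #   # codes: [num_quantizers, sum(code_lens)]; with kernel_size=3 / stride_size=2 the conv
+ #   # front-end roughly halves the mel frame count before quantization.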
1499
+
1500
+ class MiMoAudioTokenizer(PreTrainedModel):
1501
+ config_class = MiMoAudioTokenizerConfig
1502
+
1503
+ def __init__(self, config: MiMoAudioTokenizerConfig):
1504
+ super().__init__(config)
1505
+ self.config = config
1506
+ self.sampling_rate = config.sampling_rate
1507
+ self.encoder = AudioTokenizerEncoder(config=config)
1508
+ self.downsample_rate = int(config.hop_length * 2 * config.avg_pooler)
1509
+
1510
+ def get_output_length(self, mel_len):
1511
+ return self.encoder.get_output_length(mel_len)
1512
+
1513
+ @torch.no_grad()
1514
+ def encode(self, mels, input_lens, use_quantizer=True):
1515
+ return self.encoder.encode(mels, input_lens=input_lens, use_quantizer=use_quantizer)
1516
+
1517
+
1518
+ def _at_group_by_length(features, lengths, max_length):
1519
+ split_points, current_sum = [], 0
1520
+ for i, seq_len in enumerate(lengths):
1521
+ if current_sum + seq_len > max_length and current_sum > 0:
1522
+ split_points.append(i)
1523
+ current_sum = seq_len.item()
1524
+ else:
1525
+ current_sum += seq_len.item()
1526
+ group_sizes, prev = [], 0
1527
+ for point in split_points:
1528
+ group_sizes.append(point - prev)
1529
+ prev = point
1530
+ if prev < len(lengths):
1531
+ group_sizes.append(len(lengths) - prev)
1532
+ len_groups = torch.split(lengths, group_sizes)
1533
+ feature_groups = torch.split(features, [g.sum().item() for g in len_groups])
1534
+ return feature_groups, len_groups
1535
+
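+ # `_at_group_by_length` greedily packs consecutive segments into groups whose total frame
+ # count stays below `max_length`, so long batches can be encoded chunk by chunk (see the
+ # 256000-frame cap used in `tokenize_audio_batch` below).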
1536
+
1537
+ @torch.no_grad()
1538
+ def tokenize_audio_batch(mels, audio_tokenizer_encoder, segment_size=6000, device=None):
1539
+ if not mels:
1540
+ return []
1541
+ if device is None:
1542
+ device = next(audio_tokenizer_encoder.parameters()).device
1543
+ input_len_seg_per_mel = []
1544
+ for m in mels:
1545
+ input_len = m.size(0)
1546
+ segs = [segment_size] * (input_len // segment_size)
1547
+ if input_len % segment_size > 0:
1548
+ segs.append(input_len % segment_size)
1549
+ input_len_seg_per_mel.append(segs)
1550
+ input_lens_flat = [s for segs in input_len_seg_per_mel for s in segs]
1551
+ input_features = torch.cat([m.to(device) for m in mels], dim=0)
1552
+ input_lens_t = torch.tensor(input_lens_flat, dtype=torch.long, device=device)
1553
+ feature_groups, len_groups = _at_group_by_length(input_features, input_lens_t, 256000)
1554
+ encoded_parts = []
1555
+ for features, lengths in zip(feature_groups, len_groups):
1556
+ codes, _ = audio_tokenizer_encoder.encode(input_features=features, input_lens=lengths, return_codes_only=True)
1557
+ encoded_parts.append(codes)
1558
+ codes = torch.cat(encoded_parts, dim=-1).transpose(0, 1).detach()
1559
+ code_lengths = []
1560
+ for segs in input_len_seg_per_mel:
1561
+ out_len = audio_tokenizer_encoder.get_output_length(torch.tensor(segs, dtype=torch.long, device=device))
1562
+ if getattr(audio_tokenizer_encoder, "down_sample_layer", None) is not None:
1563
+ avg = audio_tokenizer_encoder.config.avg_pooler
1564
+ out_len = out_len // avg + (out_len % avg != 0).long()
1565
+ code_lengths.append(out_len.sum().item())
1566
+ return list(torch.split(codes, code_lengths))
1567
+
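+ # Illustrative usage (assumes the codec has been attached via
+ # `MiMoV2ForCausalLM.load_audio_tokenizer`, defined further down):
+ #
+ #   mels = [mel_clip_0, mel_clip_1]                     # each [n_frames_i, n_mels]
+ #   code_list = tokenize_audio_batch(mels, model.audio_tokenizer.encoder)
+ #   # code_list[i]: [n_codes_i, num_quantizers] discrete codes for the i-th clip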
1568
+
1569
+ # ---------------------------------------------------------------------------
1570
+ # LLM backbone
1571
+ # ---------------------------------------------------------------------------
1572
+
1573
+
1574
+ class MiMoV2Model(PreTrainedModel):
1575
+ config_class = MiMoV2Config
1576
+ attention_projection_layout = "split"
1577
+
1578
+ def __init__(self, config):
1579
+ super().__init__(config)
1580
+ self.attention_projection_layout = getattr(
1581
+ config, "attention_projection_layout", self.attention_projection_layout
1582
+ )
1583
+ self.vocab_size = config.vocab_size
1584
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
1585
+ self.layers = nn.ModuleList(
1586
+ [
1587
+ MiMoV2DecoderLayer(
1588
+ config,
1589
+ layer_idx,
1590
+ attention_projection_layout=self.attention_projection_layout,
1591
+ )
1592
+ for layer_idx in range(config.num_hidden_layers)
1593
+ ]
1594
+ )
1595
+ self.norm = MiMoV2RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
1596
+ self.rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=False)
1597
+ self.swa_rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=True)
1598
+ self.has_sliding_layers = any(pattern == 1 for pattern in config.hybrid_layer_pattern)
1599
+ self.config.layer_types = [
1600
+ "sliding_attention" if config.hybrid_layer_pattern[i] == 1 else "full_attention"
1601
+ for i in range(config.num_hidden_layers)
1602
+ ]
1603
+ self.post_init()
1604
+
1605
+ def get_input_embeddings(self):
1606
+ return self.embed_tokens
1607
+
1608
+ def set_input_embeddings(self, value):
1609
+ self.embed_tokens = value
1610
+
1611
+ def forward(
1612
+ self,
1613
+ input_ids: Optional[torch.LongTensor] = None,
1614
+ attention_mask: Optional[torch.Tensor] = None,
1615
+ position_ids: Optional[torch.LongTensor] = None,
1616
+ past_key_values: Optional[Cache] = None,
1617
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1618
+ use_cache: Optional[bool] = None,
1619
+ cache_position: Optional[torch.LongTensor] = None,
1620
+ **kwargs: Unpack[TransformersKwargs],
1621
+ ) -> BaseModelOutputWithPast:
1622
+ if (input_ids is None) ^ (inputs_embeds is not None):
1623
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1624
+
1625
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1626
+
1627
+ if inputs_embeds is None:
1628
+ inputs_embeds = self.embed_tokens(input_ids)
1629
+
1630
+ if use_cache and past_key_values is None:
1631
+ past_key_values = DynamicCache(config=self.config)
1632
+
1633
+ if cache_position is None:
1634
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1635
+ cache_position = torch.arange(
1636
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1637
+ )
1638
+
1639
+ if position_ids is None:
1640
+ position_ids = cache_position.unsqueeze(0)
1641
+
1642
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
1643
+ mask_kwargs = {
1644
+ "config": self.config,
1645
+ "input_embeds": inputs_embeds,
1646
+ "attention_mask": attention_mask,
1647
+ "cache_position": cache_position,
1648
+ "past_key_values": past_key_values,
1649
+ "position_ids": position_ids,
1650
+ }
1651
+ causal_mask_mapping = {
1652
+ "full_attention": create_causal_mask(**mask_kwargs),
1653
+ }
1654
+ if self.has_sliding_layers:
1655
+ if getattr(self.config, "sliding_window", None) is None:
1656
+ raise ValueError("MiMoV2 config `sliding_window` must be set when hybrid_layer_pattern uses SWA.")
1657
+ causal_mask_mapping["sliding_window_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
1658
+
1659
+ hidden_states = inputs_embeds
1660
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1661
+ swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids)
1662
+
1663
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
1664
+ hidden_states = decoder_layer(
1665
+ hidden_states,
1666
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
1667
+ position_embeddings=position_embeddings
1668
+ if decoder_layer.attention_type == "full_attention"
1669
+ else swa_position_embeddings,
1670
+ position_ids=position_ids,
1671
+ past_key_values=past_key_values,
1672
+ use_cache=use_cache,
1673
+ cache_position=cache_position,
1674
+ **kwargs,
1675
+ )
1676
+
1677
+ hidden_states = self.norm(hidden_states)
1678
+ return BaseModelOutputWithPast(
1679
+ last_hidden_state=hidden_states,
1680
+ past_key_values=past_key_values if use_cache else None,
1681
+ )
1682
+
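+ # Note on the hybrid attention layout: `config.hybrid_layer_pattern` is a per-layer list where
+ # 1 selects sliding-window attention (using `swa_rotary_emb` and the sliding-window mask) and
+ # 0 selects full attention; e.g. a pattern such as [1, 1, 1, 0, ...] would interleave three
+ # SWA layers with each full-attention layer.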
1683
+
1684
+ class MiMoV2ForCausalLM(PreTrainedModel, GenerationMixin):
1685
+ config_class = MiMoV2Config
1686
+ model_class = MiMoV2Model
1687
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
1688
+ _tp_plan = {"lm_head": "colwise_rep"}
1689
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
1690
+ _keys_to_ignore_on_load_unexpected = [
1691
+ r"model\.(swa_)?rotary_emb\.inv_freq",
1692
+ r"model\.layers\.\d+\.self_attn\.rotary_emb\.inv_freq",
1693
+ r"model\.layers\.\d+\.self_attn\.rotary_emb\.(cos_cached|sin_cached)",
1694
+ r"model\.mtp\..*",
1695
+ ]
1696
+ _keys_to_ignore_on_load_missing = [
1697
+ r"audio_encoder\.input_local_transformer\.embed_tokens\.weight",
1698
+ ]
1699
+
1700
+ def __init__(self, config):
1701
+ super().__init__(config)
1702
+ self.model = self.model_class(config)
1703
+ self.vocab_size = config.vocab_size
1704
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1705
+
1706
+ if config.vision_config:
1707
+ self.visual = MiMoVisionTransformer(_as_namespace(config.vision_config))
1708
+ if config.audio_config:
1709
+ audio_cfg = _as_namespace(config.audio_config)
1710
+ self.speech_embeddings = _build_speech_embeddings(audio_cfg)
1711
+ self.audio_encoder = MiMoAudioEncoder(audio_cfg)
1712
+
1713
+ self.audio_tokenizer = None
1714
+ self.post_init()
1715
+
1716
+ def load_audio_tokenizer(self, path: str, device: torch.device | str | None = None, dtype: torch.dtype = torch.bfloat16):
1717
+ """Load the audio tokenizer from a directory containing config.json and model.safetensors."""
1718
+ import json
1719
+ import os
1720
+
1721
+ from safetensors.torch import load_file
1722
+
1723
+ config_path = os.path.join(path, "config.json")
1724
+ with open(config_path) as f:
1725
+ config_dict = json.load(f)
1726
+ tokenizer_config = MiMoAudioTokenizerConfig(**config_dict)
1727
+ tokenizer_model = MiMoAudioTokenizer(tokenizer_config)
1728
+
1729
+ safetensors_path = os.path.join(path, "model.safetensors")
1730
+ bin_path = os.path.join(path, "pytorch_model.bin")
1731
+ if os.path.exists(safetensors_path):
1732
+ state_dict = load_file(safetensors_path, device="cpu")
1733
+ elif os.path.exists(bin_path):
1734
+ state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
1735
+ else:
1736
+ raise FileNotFoundError(f"No model weights found in {path}")
1737
+ tokenizer_model.load_state_dict(state_dict, strict=False)
1738
+
1739
+ if device is None:
1740
+ device = next(self.parameters()).device
1741
+ tokenizer_model = tokenizer_model.to(device=device, dtype=dtype)
1742
+ tokenizer_model.eval()
1743
+ tokenizer_model.requires_grad_(False)
1744
+ self.audio_tokenizer = tokenizer_model
1745
+
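+ # Typical usage (paths are placeholders): load the LLM checkpoint first, then attach the
+ # codec so that mel inputs can be tokenized on the fly:
+ #
+ #   model = MiMoV2ForCausalLM.from_pretrained("path/to/mimo-v2")
+ #   model.load_audio_tokenizer("path/to/mimo-audio-tokenizer")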
1746
+ def get_input_embeddings(self):
1747
+ return self.model.embed_tokens
1748
+
1749
+ def set_input_embeddings(self, value):
1750
+ self.model.embed_tokens = value
1751
+
1752
+ def get_output_embeddings(self):
1753
+ return self.lm_head
1754
+
1755
+ def set_output_embeddings(self, new_embeddings):
1756
+ self.lm_head = new_embeddings
1757
+
1758
+ def _get_multimodal_embeds(
1759
+ self,
1760
+ input_ids: torch.Tensor,
1761
+ inputs_embeds: torch.Tensor,
1762
+ pixel_values: Optional[torch.Tensor] = None,
1763
+ image_grid_thw: Optional[torch.Tensor] = None,
1764
+ image_embeds: Optional[torch.Tensor] = None,
1765
+ video_pixel_values: Optional[torch.Tensor] = None,
1766
+ video_grid_thw: Optional[torch.Tensor] = None,
1767
+ video_embeds: Optional[torch.Tensor] = None,
1768
+ audio_codes: Optional[torch.Tensor] = None,
1769
+ audio_embeds: Optional[torch.Tensor] = None,
1770
+ ) -> torch.Tensor:
1771
+ has_image = image_embeds is not None or pixel_values is not None
1772
+ has_video = video_embeds is not None or video_pixel_values is not None
1773
+ has_audio = audio_embeds is not None or audio_codes is not None
1774
+
1775
+ if not (has_image or has_video or has_audio):
1776
+ return inputs_embeds
1777
+
1778
+ inputs_embeds = inputs_embeds.clone()
1779
+
1780
+ if has_image:
1781
+ cur_image_embeds = image_embeds if image_embeds is not None else self.visual(pixel_values=pixel_values, grid_thw=image_grid_thw)
1782
+ _replace_modal_embeddings_inplace(
1783
+ input_ids=input_ids, inputs_embeds=inputs_embeds,
1784
+ token_id=getattr(self.config, "image_token_id", None), modal_embeds=cur_image_embeds,
1785
+ )
1786
+
1787
+ if has_video:
1788
+ cur_video_embeds = video_embeds if video_embeds is not None else self.visual(pixel_values=video_pixel_values, grid_thw=video_grid_thw)
1789
+ _replace_modal_embeddings_inplace(
1790
+ input_ids=input_ids, inputs_embeds=inputs_embeds,
1791
+ token_id=getattr(self.config, "video_token_id", None), modal_embeds=cur_video_embeds,
1792
+ )
1793
+
1794
+ if has_audio:
1795
+ _replace_modal_embeddings_inplace(
1796
+ input_ids=input_ids, inputs_embeds=inputs_embeds,
1797
+ token_id=getattr(self.config, "audio_token_id", None),
1798
+ modal_embeds=self.audio_encoder(
1799
+ speech_embeddings=self.speech_embeddings, audio_codes=audio_codes, audio_embeds=audio_embeds,
1800
+ ),
1801
+ )
1802
+
1803
+ return inputs_embeds
1804
+
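+ # Note on multimodal inputs: `_get_multimodal_embeds` replaces the embeddings at image / video /
+ # audio placeholder-token positions with rows of the corresponding encoder output, so the number
+ # of placeholder tokens in the prompt is expected to match the number of embedding rows produced
+ # (the processor is expected to guarantee this).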
1805
+ @can_return_tuple
1806
+ def forward(
1807
+ self,
1808
+ input_ids: Optional[torch.LongTensor] = None,
1809
+ attention_mask: Optional[torch.Tensor] = None,
1810
+ position_ids: Optional[torch.LongTensor] = None,
1811
+ past_key_values: Optional[Cache] = None,
1812
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1813
+ labels: Optional[torch.LongTensor] = None,
1814
+ use_cache: Optional[bool] = None,
1815
+ cache_position: Optional[torch.LongTensor] = None,
1816
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1817
+ pixel_values: Optional[torch.Tensor] = None,
1818
+ image_grid_thw: Optional[torch.Tensor] = None,
1819
+ image_embeds: Optional[torch.Tensor] = None,
1820
+ video_pixel_values: Optional[torch.Tensor] = None,
1821
+ video_grid_thw: Optional[torch.Tensor] = None,
1822
+ video_embeds: Optional[torch.Tensor] = None,
1823
+ audio_codes: Optional[torch.Tensor] = None,
1824
+ audio_embeds: Optional[torch.Tensor] = None,
1825
+ **kwargs: Unpack[TransformersKwargs],
1826
+ ) -> CausalLMOutputWithPast:
1827
+ if inputs_embeds is None and input_ids is not None:
1828
+ inputs_embeds = self.model.get_input_embeddings()(input_ids)
1829
+ if any(x is not None for x in [pixel_values, image_embeds, video_pixel_values, video_embeds, audio_codes, audio_embeds]):
1830
+ inputs_embeds = self._get_multimodal_embeds(
1831
+ input_ids=input_ids, inputs_embeds=inputs_embeds,
1832
+ pixel_values=pixel_values, image_grid_thw=image_grid_thw, image_embeds=image_embeds,
1833
+ video_pixel_values=video_pixel_values, video_grid_thw=video_grid_thw, video_embeds=video_embeds,
1834
+ audio_codes=audio_codes, audio_embeds=audio_embeds,
1835
+ )
1836
+ input_ids = None
1837
+
1838
+ outputs: BaseModelOutputWithPast = self.model(
1839
+ input_ids=input_ids,
1840
+ attention_mask=attention_mask,
1841
+ position_ids=position_ids,
1842
+ past_key_values=past_key_values,
1843
+ inputs_embeds=inputs_embeds,
1844
+ use_cache=use_cache,
1845
+ cache_position=cache_position,
1846
+ **kwargs,
1847
+ )
1848
+
1849
+ hidden_states = outputs.last_hidden_state
1850
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1851
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1852
+
1853
+ loss = None
1854
+ if labels is not None:
1855
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
1856
+
1857
+ return CausalLMOutputWithPast(
1858
+ loss=loss,
1859
+ logits=logits,
1860
+ past_key_values=outputs.past_key_values,
1861
+ hidden_states=outputs.hidden_states,
1862
+ attentions=outputs.attentions,
1863
+ )
1864
+
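+ # Minimal text-only generation sketch (illustrative; paths and prompt handling are placeholders,
+ # and multimodal inputs would additionally be prepared by the project's processor):
+ #
+ #   from transformers import AutoTokenizer
+ #   tok = AutoTokenizer.from_pretrained("path/to/mimo-v2")
+ #   model = MiMoV2ForCausalLM.from_pretrained("path/to/mimo-v2").eval()
+ #   ids = tok("Hello", return_tensors="pt").input_ids
+ #   out = model.generate(ids, max_new_tokens=32)
+ #   print(tok.decode(out[0], skip_special_tokens=True))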
1865
+
1866
+ __all__ = [
1867
+ "MiMoAudioTokenizer",
1868
+ "MiMoAudioTokenizerConfig",
1869
+ "MiMoV2Attention",
1870
+ "MiMoV2DecoderLayer",
1871
+ "MiMoV2ForCausalLM",
1872
+ "MiMoV2MLP",
1873
+ "MiMoV2MoE",
1874
+ "MiMoV2MoEGate",
1875
+ "MiMoV2Model",
1876
+ "MiMoV2RMSNorm",
1877
+ "MiMoV2RotaryEmbedding",
1878
+ ]